Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-19 02:40:15 +08:00)

Commit fd503d8a96: Merge branch 'master' of github.com:comfyanonymous/ComfyUI
@@ -9,7 +9,7 @@ import zipfile
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import TypedDict
+from typing import TypedDict, Optional
 
 import requests
 from typing_extensions import NotRequired
@@ -135,12 +135,13 @@ class FrontendManager:
         return match_result.group(1), match_result.group(2), match_result.group(3)
 
     @classmethod
-    def init_frontend_unsafe(cls, version_string: str) -> str:
+    def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
         """
        Initializes the frontend for the specified version.

        Args:
            version_string (str): The version string.
+            provider (FrontEndProvider, optional): The provider to use. Defaults to None.

        Returns:
            str: The path to the initialized frontend.
@@ -153,7 +154,7 @@ class FrontendManager:
            return cls.DEFAULT_FRONTEND_PATH

        repo_owner, repo_name, version = cls.parse_version_string(version_string)
-        provider = FrontEndProvider(repo_owner, repo_name)
+        provider = provider or FrontEndProvider(repo_owner, repo_name)
        release = provider.get_release(version)

        semantic_version = release["tag_name"].lstrip("v")
@@ -161,15 +162,21 @@ class FrontendManager:
            Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
        )
        if not os.path.exists(web_root):
-            os.makedirs(web_root, exist_ok=True)
-            logging.info(
-                "Downloading frontend(%s) version(%s) to (%s)",
-                provider.folder_name,
-                semantic_version,
-                web_root,
-            )
-            logging.debug(release)
-            download_release_asset_zip(release, destination_path=web_root)
+            try:
+                os.makedirs(web_root, exist_ok=True)
+                logging.info(
+                    "Downloading frontend(%s) version(%s) to (%s)",
+                    provider.folder_name,
+                    semantic_version,
+                    web_root,
+                )
+                logging.debug(release)
+                download_release_asset_zip(release, destination_path=web_root)
+            finally:
+                # Clean up the directory if it is empty, i.e. the download failed
+                if not os.listdir(web_root):
+                    os.rmdir(web_root)

        return web_root

    @classmethod
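Aside: taken together, the hunks above make two behavioural changes — init_frontend_unsafe can now be handed a pre-built FrontEndProvider (handy for tests), and a failed download no longer leaves an empty version directory behind. A minimal sketch of exercising both; the module path mirrors the patch targets in the test hunk at the end of this diff and is otherwise an assumption, and the mocked release payload is purely illustrative:

    from unittest.mock import MagicMock, patch

    from app.frontend_management import FrontendManager, FrontEndProvider  # assumed module path

    provider = MagicMock(spec=FrontEndProvider)   # injected instead of being built from the version string
    provider.folder_name = "test-owner_test-repo"
    provider.get_release.return_value = {"tag_name": "v1.0.0"}  # illustrative release payload

    with patch("app.frontend_management.download_release_asset_zip", side_effect=RuntimeError("network down")):
        try:
            FrontendManager.init_frontend_unsafe("test-owner/test-repo@1.0.0", provider)
        except RuntimeError:
            pass  # the finally block has already removed the empty web_root directory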
@@ -603,7 +603,9 @@ class PromptServer(ExecutorToClientProgress):
         @routes.post("/internal/models/download")
         async def download_handler(request):
             async def report_progress(filename: str, status: DownloadModelStatus):
-                await self.send_json("download_progress", status.to_dict())
+                payload = status.to_dict()
+                payload['download_path'] = filename
+                await self.send_json("download_progress", payload)
 
             data = await request.json()
             url = data.get('url')
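Aside: for clients, the practical effect is that every download_progress event now says which file it refers to. A hypothetical consumer (the event/data framing is assumed from how send_json is used here, not spelled out in this diff):

    def handle_event(event_type: str, data: dict) -> None:
        if event_type == "download_progress":
            # "download_path" is the field added above; the rest comes from DownloadModelStatus.to_dict()
            print(f"{data.get('download_path')}: {data}")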
@@ -32,7 +32,7 @@ from . import utils
 from .cldm import cldm, mmdit
 from .ldm import hydit
 from .ldm.cascade import controlnet as cascade_controlnet
-from .ldm.flux import controlnet_xlabs
+from .ldm.flux import controlnet as controlnet_flux
 from .ldm.flux.controlnet_instantx import InstantXControlNetFlux
 from .ldm.flux.controlnet_instantx_format2 import InstantXControlNetFluxFormat2
 from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
@@ -152,7 +152,7 @@ class ControlBase:
                     elif self.strength_type == StrengthType.LINEAR_UP:
                         x *= (self.strength ** float(len(control_output) - i))
 
-                    if x.dtype != output_dtype:
+                    if output_dtype is not None and x.dtype != output_dtype:
                         x = x.to(output_dtype)
 
                 out[key].append(x)
@@ -211,7 +211,6 @@ class ControlNet(ControlBase):
         if self.manual_cast_dtype is not None:
             dtype = self.manual_cast_dtype
 
-        output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
@@ -241,7 +240,7 @@ class ControlNet(ControlBase):
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
 
         control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
-        return self.control_merge(control, control_prev, output_dtype)
+        return self.control_merge(control, control_prev, output_dtype=None)
 
     def copy(self):
         c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@@ -441,7 +440,7 @@ def load_controlnet_hunyuandit(controlnet_data):
     return control
 
 
-def load_controlnet_flux_instantx(sd, controlnet_class, weight_dtype, full_path):
+def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full_path):
     keys_to_keep = [
         "controlnet_",
         "single_transformer_blocks",
@@ -498,19 +497,32 @@ def load_controlnet_flux_instantx(sd, controlnet_class, weight_dtype, full_path)
 
 def load_controlnet_flux_xlabs(sd):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
-    control_model = controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_flux.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
 
+def load_controlnet_flux_instantx(sd):
+    new_sd = model_detection.convert_diffusers_mmdit(sd, "")
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    control_model = controlnet_flux.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)
+
+    latent_format = latent_formats.Flux()
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
+
 def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
     if "controlnet_mode_embedder.weight" in controlnet_data:
-        return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path)
+        return load_controlnet_flux_instantx_union(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path)
     if "controlnet_mode_embedder.fc.weight" in controlnet_data:
-        return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path)
+        return load_controlnet_flux_instantx_union(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path)
     if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
         return load_controlnet_hunyuandit(controlnet_data)
     if "lora_controlnet" in controlnet_data:
@@ -572,8 +584,10 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
     elif "controlnet_blocks.0.weight" in controlnet_data: # SD3 diffusers format
         if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
             return load_controlnet_flux_xlabs(controlnet_data)
-        else:
+        elif "pos_embed_input.proj.weight" in controlnet_data:
             return load_controlnet_mmdit(controlnet_data)
+        elif "controlnet_x_embedder.weight" in controlnet_data:
+            return load_controlnet_flux_instantx(controlnet_data)
 
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
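Aside: to keep the detection logic above readable at a glance, here is the key-to-loader routing it implements, restated as a plain mapping (no new behaviour, just a summary of the two dispatch hunks above):

    # state-dict key that identifies the checkpoint   -> loader selected in load_controlnet()
    CONTROLNET_KEY_ROUTING = {
        "controlnet_mode_embedder.weight":              "load_controlnet_flux_instantx_union(..., InstantXControlNetFluxFormat2, ...)",
        "controlnet_mode_embedder.fc.weight":           "load_controlnet_flux_instantx_union(..., InstantXControlNetFlux, ...)",
        "after_proj_list.18.bias":                      "load_controlnet_hunyuandit(...)",
        # inside the diffusers branch ("controlnet_blocks.0.weight" present):
        "double_blocks.0.img_attn.norm.key_norm.scale": "load_controlnet_flux_xlabs(...)    # image-hint Flux",
        "pos_embed_input.proj.weight":                  "load_controlnet_mmdit(...)         # SD3",
        "controlnet_x_embedder.weight":                 "load_controlnet_flux_instantx(...) # latent-input InstantX Flux",
    }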
@@ -1,4 +1,5 @@
 import torch
+import comfy.ops
 
 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
@@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
     pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
     return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
+
+try:
+    rms_norm_torch = torch.nn.functional.rms_norm
+except:
+    rms_norm_torch = None
+
+def rms_norm(x, weight, eps=1e-6):
+    if rms_norm_torch is not None:
+        return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+    else:
+        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+        return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
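Aside: as a sanity check on the helper added above, the fused torch.nn.functional.rms_norm path (available in newer PyTorch builds) and the manual fallback compute the same thing, x * rsqrt(mean(x^2) + eps) scaled by the weight. A standalone sketch, independent of comfy.ops:

    import torch

    def rms_norm_manual(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Fallback formulation used when torch.nn.functional.rms_norm is unavailable.
        rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
        return (x * rrms) * weight.to(dtype=x.dtype, device=x.device)

    x = torch.randn(2, 5, 16)
    w = torch.ones(16)

    out_manual = rms_norm_manual(x, w)
    if hasattr(torch.nn.functional, "rms_norm"):  # fused kernel exists in PyTorch >= 2.4
        out_fused = torch.nn.functional.rms_norm(x, w.shape, weight=w, eps=1e-6)
        assert torch.allclose(out_manual, out_fused, atol=1e-5)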
comfy/ldm/flux/controlnet.py (new file, 140 lines)
@@ -0,0 +1,140 @@
+#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
+
+import torch
+import math
+from torch import Tensor, nn
+from einops import rearrange, repeat
+
+from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                     MLPEmbedder, SingleStreamBlock,
+                     timestep_embedding)
+
+from .model import Flux
+import comfy.ldm.common_dit
+
+
+class ControlNetFlux(Flux):
+    def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
+
+        self.main_model_double = 19
+        self.main_model_single = 38
+        # add ControlNet blocks
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(self.params.depth):
+            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+            self.controlnet_blocks.append(controlnet_block)
+
+        self.controlnet_single_blocks = nn.ModuleList([])
+        for _ in range(self.params.depth_single_blocks):
+            self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
+
+        self.gradient_checkpointing = False
+        self.latent_input = latent_input
+        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        if not self.latent_input:
+            self.input_hint_block = nn.Sequential(
+                operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+            )
+
+    def forward_orig(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        controlnet_cond: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        y: Tensor,
+        guidance: Tensor = None,
+    ) -> Tensor:
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # running on sequences img
+        img = self.img_in(img)
+        if not self.latent_input:
+            controlnet_cond = self.input_hint_block(controlnet_cond)
+            controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+
+        controlnet_cond = self.pos_embed_input(controlnet_cond)
+        img = img + controlnet_cond
+        vec = self.time_in(timestep_embedding(timesteps, 256))
+        if self.params.guidance_embed:
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        controlnet_double = ()
+
+        for i in range(len(self.double_blocks)):
+            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
+            controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
+
+        img = torch.cat((txt, img), 1)
+
+        controlnet_single = ()
+
+        for i in range(len(self.single_blocks)):
+            img = self.single_blocks[i](img, vec=vec, pe=pe)
+            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
+
+        repeat = math.ceil(self.main_model_double / len(controlnet_double))
+        if self.latent_input:
+            out_input = ()
+            for x in controlnet_double:
+                out_input += (x,) * repeat
+        else:
+            out_input = (controlnet_double * repeat)
+
+        out = {"input": out_input[:self.main_model_double]}
+        if len(controlnet_single) > 0:
+            repeat = math.ceil(self.main_model_single / len(controlnet_single))
+            out_output = ()
+            if self.latent_input:
+                for x in controlnet_single:
+                    out_output += (x,) * repeat
+            else:
+                out_output = (controlnet_single * repeat)
+            out["output"] = out_output[:self.main_model_single]
+        return out
+
+    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
+        patch_size = 2
+        if self.latent_input:
+            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
+            hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+        else:
+            hint = hint * 2.0 - 1.0
+
+        bs, c, h, w = x.shape
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+
+        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+
+        h_len = ((h + (patch_size // 2)) // patch_size)
+        w_len = ((w + (patch_size // 2)) // patch_size)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
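Aside: the only subtle part of forward_orig above is how a ControlNet with fewer blocks than the full Flux model (19 double / 38 single blocks) gets its outputs stretched to cover every main-model block, and how that differs between the latent-input and image-hint variants. A small standalone illustration with made-up placeholder outputs:

    import math

    controlnet_double = ("d0", "d1", "d2", "d3", "d4")        # pretend the ControlNet has 5 double blocks
    main_model_double = 19
    n = math.ceil(main_model_double / len(controlnet_double))  # 4

    # latent_input=True: each output is repeated consecutively before truncation
    out_latent = ()
    for x in controlnet_double:
        out_latent += (x,) * n
    print(out_latent[:main_model_double])
    # ('d0', 'd0', 'd0', 'd0', 'd1', 'd1', 'd1', 'd1', 'd2', 'd2', 'd2', 'd2', 'd3', 'd3', 'd3', 'd3', 'd4', 'd4', 'd4')

    # image-hint path: the whole tuple is tiled, then truncated
    out_hint = (controlnet_double * n)[:main_model_double]
    print(out_hint)
    # ('d0', 'd1', 'd2', 'd3', 'd4', 'd0', 'd1', 'd2', 'd3', 'd4', 'd0', 'd1', 'd2', 'd3', 'd4', 'd0', 'd1', 'd2', 'd3')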
Deleted file (104 lines)
@@ -1,104 +0,0 @@
-#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
-
-import torch
-from torch import Tensor, nn
-from einops import rearrange, repeat
-
-from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
-                     MLPEmbedder, SingleStreamBlock,
-                     timestep_embedding)
-
-from .model import Flux
-import comfy.ldm.common_dit
-
-
-class ControlNetFlux(Flux):
-    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
-        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
-
-        # add ControlNet blocks
-        self.controlnet_blocks = nn.ModuleList([])
-        for _ in range(self.params.depth):
-            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
-            # controlnet_block = zero_module(controlnet_block)
-            self.controlnet_blocks.append(controlnet_block)
-        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.gradient_checkpointing = False
-        self.input_hint_block = nn.Sequential(
-            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
-        )
-
-    def forward_orig(
-        self,
-        img: Tensor,
-        img_ids: Tensor,
-        controlnet_cond: Tensor,
-        txt: Tensor,
-        txt_ids: Tensor,
-        timesteps: Tensor,
-        y: Tensor,
-        guidance: Tensor = None,
-    ) -> Tensor:
-        if img.ndim != 3 or txt.ndim != 3:
-            raise ValueError("Input img and txt tensors must have 3 dimensions.")
-
-        # running on sequences img
-        img = self.img_in(img)
-        controlnet_cond = self.input_hint_block(controlnet_cond)
-        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
-        controlnet_cond = self.pos_embed_input(controlnet_cond)
-        img = img + controlnet_cond
-        vec = self.time_in(timestep_embedding(timesteps, 256))
-        if self.params.guidance_embed:
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
-        vec = vec + self.vector_in(y)
-        txt = self.txt_in(txt)
-
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        pe = self.pe_embedder(ids)
-
-        block_res_samples = ()
-
-        for block in self.double_blocks:
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
-            block_res_samples = block_res_samples + (img,)
-
-        controlnet_block_res_samples = ()
-        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
-            block_res_sample = controlnet_block(block_res_sample)
-            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
-
-        return {"input": (controlnet_block_res_samples * 10)[:19]}
-
-    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
-        hint = hint * 2.0 - 1.0
-
-        bs, c, h, w = x.shape
-        patch_size = 2
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
-
-        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-
-        h_len = ((h + (patch_size // 2)) // patch_size)
-        w_len = ((w + (patch_size // 2)) // patch_size)
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
-        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
-
-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
@@ -5,7 +5,7 @@ import torch
 from torch import Tensor, nn
 
 from .math import attention, rope
-from ... import ops
+from ..common_dit import rms_norm
 
 
 class EmbedND(nn.Module):
@@ -63,10 +63,7 @@ class RMSNorm(torch.nn.Module):
         self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
 
     def forward(self, x: Tensor):
-        x_dtype = x.dtype
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(dtype=x_dtype) * ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
+        return rms_norm(x, self.scale, 1e-6)
 
 
 class QKNorm(torch.nn.Module):
@@ -356,29 +356,9 @@ class RMSNorm(torch.nn.Module):
         else:
             self.register_parameter("weight", None)
 
-    def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The normalized tensor.
-        """
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
-        """
-        x = self._norm(x)
-        if self.learnable_scale:
-            return x * self.weight.to(device=x.device, dtype=x.dtype)
-        else:
-            return x
+        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
 
 
 class SwiGLUFeedForward(nn.Module):
comfy/lora.py (149 changed lines)
@@ -15,12 +15,15 @@
 You should have received a copy of the GNU General Public License
 along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
+from __future__ import annotations
+
 import logging
 
 import torch
-from . import utils
 from . import model_base
 from . import model_management
+from . import utils
+
 
 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@@ -71,7 +74,7 @@ def load_lora(lora, to_load):
             B_name = "{}.lora.down.weight".format(x)
         elif transformers_lora in lora.keys():
             A_name = transformers_lora
-            B_name ="{}.lora_linear_layer.down.weight".format(x)
+            B_name = "{}.lora_linear_layer.down.weight".format(x)
 
         if A_name is not None:
             mid = None
@@ -82,7 +85,6 @@ def load_lora(lora, to_load):
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)
 
-
         ######## loha
         hada_w1_a_name = "{}.hada_w1_a".format(x)
         hada_w1_b_name = "{}.hada_w1_b".format(x)
@@ -105,7 +107,6 @@ def load_lora(lora, to_load):
             loaded_keys.add(hada_w2_a_name)
             loaded_keys.add(hada_w2_b_name)
 
-
         ######## lokr
         lokr_w1_name = "{}.lokr_w1".format(x)
         lokr_w2_name = "{}.lokr_w2".format(x)
@@ -153,7 +154,7 @@ def load_lora(lora, to_load):
         if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
             patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
 
-        #glora
+        # glora
         a1_name = "{}.a1.weight".format(x)
         a2_name = "{}.a2.weight".format(x)
         b1_name = "{}.b1.weight".format(x)
@@ -195,12 +196,13 @@ def load_lora(lora, to_load):
 
     return patch_dict
 
+
 def model_lora_keys_clip(model, key_map={}):
     sdk = model.state_dict().keys()
 
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
-    for b in range(32): #TODO: clean up
+    for b in range(32): # TODO: clean up
         for c in LORA_CLIP_MAP:
             k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
@@ -208,58 +210,58 @@ def model_lora_keys_clip(model, key_map={}):
                 key_map[lora_key] = k
                 lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                 key_map[lora_key] = k
-                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
+                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) # diffusers lora
                 key_map[lora_key] = k
 
             k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
                 lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                 key_map[lora_key] = k
-                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
+                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) # SDXL base
                 key_map[lora_key] = k
                 clip_l_present = True
-                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
+                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) # diffusers lora
                 key_map[lora_key] = k
 
             k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
                 if clip_l_present:
-                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
+                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) # SDXL base
                     key_map[lora_key] = k
-                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
+                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) # diffusers lora
                     key_map[lora_key] = k
                 else:
-                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
+                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) # TODO: test if this is correct for SDXL-Refiner
                     key_map[lora_key] = k
-                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
+                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) # diffusers lora
                     key_map[lora_key] = k
-                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
+                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) # cascade lora: TODO put lora key prefix in the model config
                     key_map[lora_key] = k
 
     for k in sdk:
         if k.endswith(".weight"):
-            if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
+            if k.startswith("t5xxl.transformer."): # OneTrainer SD3 lora
                 l_key = k[len("t5xxl.transformer."):-len(".weight")]
                 lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
                 key_map[lora_key] = k
-            elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
+            elif k.startswith("hydit_clip.transformer.bert."): # HunyuanDiT Lora
                 l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
                 lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
                 key_map[lora_key] = k
 
 
     k = "clip_g.transformer.text_projection.weight"
     if k in sdk:
-        key_map["lora_prior_te_text_projection"] = k #cascade lora?
+        key_map["lora_prior_te_text_projection"] = k # cascade lora?
         # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
-        key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora
+        key_map["lora_te2_text_projection"] = k # OneTrainer SD3 lora
 
     k = "clip_l.transformer.text_projection.weight"
     if k in sdk:
-        key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning
+        key_map["lora_te1_text_projection"] = k # OneTrainer SD3 lora, not necessary but omits warning
 
     return key_map
 
 
 def model_lora_keys_unet(model, key_map={}):
     sd = model.state_dict()
     sdk = sd.keys()
@@ -268,8 +270,8 @@ def model_lora_keys_unet(model, key_map={}):
         if k.startswith("diffusion_model.") and k.endswith(".weight"):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = k
-            key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
-            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+            key_map["lora_prior_unet_{}".format(key_lora)] = k # cascade lora: TODO put lora key prefix in the model config
+            key_map["{}".format(k[:-len(".weight")])] = k # generic lora format without any weird key names
 
     diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
     for k in diffusers_keys:
@@ -285,41 +287,41 @@ def model_lora_keys_unet(model, key_map={}):
                     diffusers_lora_key = diffusers_lora_key[:-2]
                 key_map[diffusers_lora_key] = unet_key
 
-    if isinstance(model, model_base.SD3): #Diffusers lora SD3
+    if isinstance(model, model_base.SD3): # Diffusers lora SD3
         diffusers_keys = utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
-                key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) # regular diffusers sd3 lora format
                key_map[key_lora] = to

-                key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
+                key_lora = "base_model.model.{}".format(k[:-len(".weight")]) # format for flash-sd3 lora and others?
                key_map[key_lora] = to

-                key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
+                key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora
                key_map[key_lora] = to

-    if isinstance(model, model_base.AuraFlow): #Diffusers lora AuraFlow
+    if isinstance(model, model_base.AuraFlow): # Diffusers lora AuraFlow
        diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
-                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) # simpletrainer and probably regular diffusers lora format
                key_map[key_lora] = to

    if isinstance(model, model_base.HunyuanDiT):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
-                key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
+                key_map["base_model.model.{}".format(key_lora)] = k # official hunyuan lora format

-    if isinstance(model, model_base.Flux): #Diffusers lora Flux
+    if isinstance(model, model_base.Flux): # Diffusers lora Flux
        diffusers_keys = utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
-                key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
-                key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
+                key_map["transformer.{}".format(k[:-len(".weight")])] = to # simpletrainer and probably regular diffusers flux lora format
+                key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # simpletrainer lycoris

    return key_map

@@ -344,6 +346,41 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
         weight[:] = weight_calc
     return weight
 
+
+def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
+    """
+    Pad a tensor to a new shape with zeros.
+
+    Args:
+        tensor (torch.Tensor): The original tensor to be padded.
+        new_shape (List[int]): The desired shape of the padded tensor.
+
+    Returns:
+        torch.Tensor: A new tensor padded with zeros to the specified shape.
+
+    Note:
+        If the new shape is smaller than the original tensor in any dimension,
+        the original tensor will be truncated in that dimension.
+    """
+    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
+        raise ValueError("The new shape must be larger than the original tensor in all dimensions")
+
+    if len(new_shape) != len(tensor.shape):
+        raise ValueError("The new shape must have the same number of dimensions as the original tensor")
+
+    # Create a new tensor filled with zeros
+    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
+
+    # Create slicing tuples for both tensors
+    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
+    new_slices = tuple(slice(0, dim) for dim in tensor.shape)
+
+    # Copy the original tensor into the new tensor
+    padded_tensor[new_slices] = tensor[orig_slices]
+
+    return padded_tensor
+
+
 def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
     for p in patches:
         strength = p[0]
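Aside: a quick illustration of the helper just added; it zero-pads each dimension on the high side, which is what the pad_weight handling in the next hunk relies on. The import path assumes this file lives at comfy/lora.py as the diff header states, and the shapes are illustrative:

    import torch
    from comfy.lora import pad_tensor_to_shape  # available once this hunk is applied

    w = torch.arange(6.0).reshape(2, 3)
    padded = pad_tensor_to_shape(w, [4, 3])   # grow dim 0 from 2 to 4, keep dim 1
    assert padded.shape == (4, 3)
    assert torch.equal(padded[:2, :], w)                   # original values preserved
    assert torch.equal(padded[2:, :], torch.zeros(2, 3))   # new rows are zeros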
@@ -363,7 +400,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             weight *= strength_model
 
         if isinstance(v, list):
-            v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
+            v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype),)
 
         patch_type = ""
         if len(v) == 1:
@@ -373,13 +410,19 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             v = v[1]
 
         if patch_type == "diff":
-            w1 = v[0]
+            diff: torch.Tensor = v[0]
+            # An extra flag to pad the weight if the diff's shape is larger than the weight
+            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
+            if do_pad_weight and diff.shape != weight.shape:
+                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
+                weight = pad_tensor_to_shape(weight, diff.shape)
+
             if strength != 0.0:
-                if w1.shape != weight.shape:
-                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                if diff.shape != weight.shape:
+                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                 else:
-                    weight += function(strength * model_management.cast_to_device(w1, weight.device, weight.dtype))
-        elif patch_type == "lora": #lora/locon
+                    weight += function(strength * model_management.cast_to_device(diff, weight.device, weight.dtype))
+        elif patch_type == "lora": # lora/locon
             mat1 = model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
             mat2 = model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
             dora_scale = v[4]
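Aside: for orientation, a "diff" patch that opts into the new padding looks roughly like this. The inner tuple layout (tensor first, then an options dict carrying pad_weight) is what the code above unpacks; the concrete shapes are made up:

    import torch

    diff = torch.randn(8, 16)                             # larger than the current 4x16 weight
    patch_value = ("diff", (diff, {"pad_weight": True}))
    # With pad_weight=True, calculate_weight() zero-pads the existing weight to diff.shape
    # (via pad_tensor_to_shape) and then applies strength * diff, instead of logging a
    # shape-mismatch warning and skipping the merge.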
@@ -389,7 +432,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                 alpha = 1.0
 
             if v[3] is not None:
-                #locon mid weights, hopefully the math is fine because I didn't properly test it
+                # locon mid weights, hopefully the math is fine because I didn't properly test it
                 mat3 = model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
                 final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                 mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
@@ -415,7 +458,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             if w1 is None:
                 dim = w1_b.shape[0]
                 w1 = torch.mm(model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
                               model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
             else:
                 w1 = model_management.cast_to_device(w1, weight.device, intermediate_dtype)
 
@@ -423,12 +466,12 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                 dim = w2_b.shape[0]
                 if t2 is None:
                     w2 = torch.mm(model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
                                   model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
                 else:
                     w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                       model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                       model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
                                       model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
             else:
                 w2 = model_management.cast_to_device(w2, weight.device, intermediate_dtype)
 
@@ -458,23 +501,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             w2a = v[3]
             w2b = v[4]
             dora_scale = v[7]
-            if v[5] is not None: #cp decomposition
+            if v[5] is not None: # cp decomposition
                 t1 = v[5]
                 t2 = v[6]
                 m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                   model_management.cast_to_device(t1, weight.device, intermediate_dtype),
                                   model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
                                   model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
 
                 m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                   model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                   model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
                                   model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
             else:
                 m1 = torch.mm(model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
                               model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
                 m2 = torch.mm(model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
                               model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
 
             try:
                 lora_diff = (m1 * m2).reshape(weight.shape)
@@ -463,6 +463,8 @@ def _unload_model_clones(model, unload_weights_only=True, force_unload=True) ->
     if not force_unload:
         if unload_weights_only and unload_weight == False:
             return None
+    else:
+        unload_weight = True
 
     for i in to_unload:
         logging.debug("unload clone {} {}".format(i, unload_weight))
@@ -566,6 +566,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
         ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
         ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
+        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
+        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
     }
 
     for k in MAP_BASIC:
@@ -2,6 +2,7 @@ import argparse
 
 import pytest
 from requests.exceptions import HTTPError
+from unittest.mock import patch
 
 from comfy.app.frontend_management import (
     FrontendManager,
@@ -84,6 +85,35 @@ def test_init_frontend_invalid_provider():
     with pytest.raises(HTTPError):
         FrontendManager.init_frontend_unsafe(version_string)
 
+
+@pytest.fixture
+def mock_os_functions():
+    with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
+         patch('app.frontend_management.os.listdir') as mock_listdir, \
+         patch('app.frontend_management.os.rmdir') as mock_rmdir:
+        mock_listdir.return_value = []  # Simulate empty directory
+        yield mock_makedirs, mock_listdir, mock_rmdir
+
+
+@pytest.fixture
+def mock_download():
+    with patch('app.frontend_management.download_release_asset_zip') as mock:
+        mock.side_effect = Exception("Download failed")  # Simulate download failure
+        yield mock
+
+
+def test_finally_block(mock_os_functions, mock_download, mock_provider):
+    # Arrange
+    mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
+    version_string = 'test-owner/test-repo@1.0.0'
+
+    # Act & Assert
+    with pytest.raises(Exception):
+        FrontendManager.init_frontend_unsafe(version_string, mock_provider)
+
+    # Assert
+    mock_makedirs.assert_called_once()
+    mock_download.assert_called_once()
+    mock_listdir.assert_called_once()
+    mock_rmdir.assert_called_once()
+
+
 def test_parse_version_string():
     version_string = "owner/repo@1.0.0"