Merge branch 'master' into feat/api-nodes/tencent-UV-unwrap

Commit 5a5a6abd98, mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-16 08:22:36 +08:00)
@@ -227,7 +227,7 @@ Put your VAE in: models/vae

 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```

 This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:

@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         super().cleanup()


+class QwenFunControlNet(ControlNet):
+    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
+        # Fun checkpoints are more sensitive to high strengths in the generic
+        # ControlNet merge path. Use a soft response curve so strength=1.0 stays
+        # unchanged while >1 grows more gently.
+        original_strength = self.strength
+        self.strength = math.sqrt(max(self.strength, 0.0))
+        try:
+            return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
+        finally:
+            self.strength = original_strength
+
+    def pre_run(self, model, percent_to_timestep_function):
+        super().pre_run(model, percent_to_timestep_function)
+        self.set_extra_arg("base_model", model.diffusion_model)
+
+    def copy(self):
+        c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
 class ControlLoraOps:
     class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(self, in_features: int, out_features: int, bias: bool = True,

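The soft response curve in `QwenFunControlNet.get_control` is plain `sqrt` applied to the user strength. Not part of the diff, just a standalone illustration of how the remapping behaves around the fixed point at 1.0:

```python
import math

# strength 1.0 maps to itself; values above 1 are compressed,
# values below 1 are lifted toward 1.
for s in [0.0, 0.25, 0.5, 1.0, 2.0, 4.0]:
    print(f"user strength {s:4.2f} -> effective {math.sqrt(max(s, 0.0)):.3f}")
# 0.00 -> 0.000, 0.25 -> 0.500, 0.50 -> 0.707,
# 1.00 -> 1.000, 2.00 -> 1.414, 4.00 -> 2.000
```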
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):

 def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
     control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    sd = model_config.process_unet_state_dict(sd)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)

@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control


+def load_controlnet_qwen_fun(sd, model_options={}):
+    load_device = comfy.model_management.get_torch_device()
+    weight_dtype = comfy.utils.weight_dtype(sd)
+    unet_dtype = model_options.get("dtype", weight_dtype)
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    in_features = sd["control_img_in.weight"].shape[1]
+    inner_dim = sd["control_img_in.weight"].shape[0]
+
+    block_weight = sd["control_blocks.0.attn.to_q.weight"]
+    attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
+    num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
+
+    model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
+        control_in_features=in_features,
+        inner_dim=inner_dim,
+        num_attention_heads=num_attention_heads,
+        attention_head_dim=attention_head_dim,
+        num_control_blocks=5,
+        main_model_double=60,
+        injection_layers=(0, 12, 24, 36, 48),
+        operations=operations,
+        device=comfy.model_management.unet_offload_device(),
+        dtype=unet_dtype,
+    )
+    model = controlnet_load_state_dict(model, sd)
+
+    latent_format = comfy.latent_formats.Wan21()
+    control = QwenFunControlNet(
+        model,
+        compression_ratio=1,
+        latent_format=latent_format,
+        # Fun checkpoints already expect their own 33-channel context handling.
+        # Enabling generic concat_mask injects an extra mask channel at apply-time
+        # and breaks the intended fallback packing path.
+        concat_mask=False,
+        load_device=load_device,
+        manual_cast_dtype=manual_cast_dtype,
+        extra_conds=[],
+    )
+    return control
+
 def convert_mistoline(sd):
     return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})

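`load_controlnet_qwen_fun` infers the network geometry from tensor shapes rather than a config file. A standalone sketch of that inference with made-up shapes (not values from a real checkpoint):

```python
import torch

sd = {
    "control_img_in.weight": torch.zeros(3072, 132),               # [inner_dim, in_features]
    "control_blocks.0.attn.to_q.weight": torch.zeros(3072, 3072),  # [heads * head_dim, inner_dim]
    "control_blocks.0.attn.norm_q.weight": torch.zeros(128),       # per-head norm -> head_dim
}
in_features = sd["control_img_in.weight"].shape[1]                       # 132
inner_dim = sd["control_img_in.weight"].shape[0]                         # 3072
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]  # 128
num_attention_heads = max(1, sd["control_blocks.0.attn.to_q.weight"].shape[0] // max(1, attention_head_dim))  # 24
```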
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
         return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
     elif "controlnet_x_embedder.weight" in controlnet_data:
         return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
+    elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
+        return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
     elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
         return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

@@ -3,7 +3,6 @@ from torch import Tensor, nn

 from comfy.ldm.flux.layers import (
     MLPEmbedder,
-    RMSNorm,
     ModulationOut,
 )

@@ -29,7 +28,7 @@ class Approximator(nn.Module):
         super().__init__()
         self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
         self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
-        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
+        self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
         self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

     @property

@@ -4,8 +4,6 @@ from functools import lru_cache
 import torch
 from torch import nn

-from comfy.ldm.flux.layers import RMSNorm
-
 class NerfEmbedder(nn.Module):
     """

@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
         # We now need to generate parameters for 3 matrices.
         total_params = 3 * hidden_size_x**2 * mlp_ratio
         self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
-        self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
         self.mlp_ratio = mlp_ratio

@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
 class NerfFinalLayer(nn.Module):
     def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
         super().__init__()
-        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
         self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)

     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
 class NerfFinalLayerConv(nn.Module):
     def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
         super().__init__()
-        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
         self.conv = operations.Conv2d(
             in_channels=hidden_size,
             out_channels=out_channels,

@@ -5,9 +5,9 @@ import torch
 from torch import Tensor, nn

 from .math import attention, rope
-import comfy.ops
-import comfy.ldm.common_dit

+# Fix import for some custom nodes, TODO: delete eventually.
+RMSNorm = None

 class EmbedND(nn.Module):
     def __init__(self, dim: int, theta: int, axes_dim: list):

@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
         operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
     )

-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
-    def forward(self, x: Tensor):
-        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
-
 class QKNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):
         super().__init__()
-        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
+        self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
+        self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)

     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
         q = self.query_norm(q)

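The repeated `RMSNorm(...)` to `operations.RMSNorm(...)` substitutions in this and the surrounding hunks swap the file-local wrapper for the ops-layer implementation, which handles weight casting; its parameter is named `weight` instead of `scale`, which is why the key-renaming hooks appear later in this commit. The normalization itself is unchanged. A standalone sketch of the math, assuming the eps=1e-6 form the deleted class used via `comfy.ldm.common_dit.rms_norm`:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: rescale by the root-mean-square of the last dim and apply a
    # learned per-channel gain; unlike LayerNorm there is no mean subtraction.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

x = torch.randn(2, 4, 8)
print(rms_norm(x, torch.ones(8)).shape)  # torch.Size([2, 4, 8])
```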
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):


 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)

@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):

         self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)

-        self.flipped_img_txt = flipped_img_txt
-
     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
         if self.modulation:
             img_mod1, img_mod2 = self.img_mod(vec)

@@ -224,32 +214,17 @@ class DoubleStreamBlock(nn.Module):
         del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

-        if self.flipped_img_txt:
-            q = torch.cat((img_q, txt_q), dim=2)
-            del img_q, txt_q
-            k = torch.cat((img_k, txt_k), dim=2)
-            del img_k, txt_k
-            v = torch.cat((img_v, txt_v), dim=2)
-            del img_v, txt_v
-            # run actual attention
-            attn = attention(q, k, v,
-                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
-            del q, k, v
-
-            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
-        else:
-            q = torch.cat((txt_q, img_q), dim=2)
-            del txt_q, img_q
-            k = torch.cat((txt_k, img_k), dim=2)
-            del txt_k, img_k
-            v = torch.cat((txt_v, img_v), dim=2)
-            del txt_v, img_v
-            # run actual attention
-            attn = attention(q, k, v,
-                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
-            del q, k, v
-
-            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+        q = torch.cat((txt_q, img_q), dim=2)
+        del txt_q, img_q
+        k = torch.cat((txt_k, img_k), dim=2)
+        del txt_k, img_k
+        v = torch.cat((txt_v, img_v), dim=2)
+        del txt_v, img_v
+        # run actual attention
+        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
+
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
         img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)

@@ -16,7 +16,6 @@ from .layers import (
     SingleStreamBlock,
     timestep_embedding,
     Modulation,
-    RMSNorm
 )


 @dataclass

@@ -81,7 +80,7 @@ class Flux(nn.Module):
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

         if params.txt_norm:
-            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+            self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
         else:
             self.txt_norm = None

@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
-                    flipped_img_txt=True,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)

@@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
             extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
             txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)

-        ids = torch.cat((img_ids, txt_ids), dim=1)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)

         img_len = img.shape[1]
         if txt_mask is not None:
             attn_mask_len = img_len + txt.shape[1]
             attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
-            attn_mask[:, 0, img_len:] = txt_mask
+            attn_mask[:, 0, :txt.shape[1]] = txt_mask
         else:
             attn_mask = None

@@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
                 if add is not None:
                     img += add

-        img = torch.cat((img, txt), 1)
+        img = torch.cat((txt, img), 1)

         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"

@@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
                 if i < len(control_o):
                     add = control_o[i]
                     if add is not None:
-                        img[:, : img_len] += add
+                        img[:, txt.shape[1]: img_len + txt.shape[1]] += add

-        img = img[:, : img_len]
+        img = img[:, txt.shape[1]: img_len + txt.shape[1]]
         if ref_latent is not None:
             img = img[:, ref_latent.shape[1]:]

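With `flipped_img_txt` gone, the joint sequence is always `[txt | img]`, so every place that used to address image rows from offset 0 now shifts by `txt.shape[1]`, as the indexing changes above do. A standalone toy illustration (not the model code):

```python
import torch

txt_len, img_len, dim = 3, 5, 4
txt = torch.zeros(1, txt_len, dim)
img = torch.ones(1, img_len, dim)
seq = torch.cat((txt, img), 1)             # txt-first layout

add = torch.full((1, img_len, dim), 0.5)   # a controlnet residual for image tokens
seq[:, txt_len: img_len + txt_len] += add  # matches the updated slicing

img_out = seq[:, txt_len: img_len + txt_len]
print(torch.allclose(img_out, img + add))  # True
```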
@@ -2,6 +2,196 @@ import torch
 import math

 from .model import QwenImageTransformer2DModel
+from .model import QwenImageTransformerBlock
+
+
+class QwenImageFunControlBlock(QwenImageTransformerBlock):
+    def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
+        super().__init__(
+            dim=dim,
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=attention_head_dim,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+        self.has_before_proj = has_before_proj
+        if has_before_proj:
+            self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+        self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+
+
+class QwenImageFunControlNetModel(torch.nn.Module):
+    def __init__(
+        self,
+        control_in_features=132,
+        inner_dim=3072,
+        num_attention_heads=24,
+        attention_head_dim=128,
+        num_control_blocks=5,
+        main_model_double=60,
+        injection_layers=(0, 12, 24, 36, 48),
+        dtype=None,
+        device=None,
+        operations=None,
+    ):
+        super().__init__()
+        self.dtype = dtype
+        self.main_model_double = main_model_double
+        self.injection_layers = tuple(injection_layers)
+        # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
+        # to the reference Gen2/VideoX implementation around strength=1.
+        self.hint_scale = 1.0
+        self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
+
+        self.control_blocks = torch.nn.ModuleList([])
+        for i in range(num_control_blocks):
+            self.control_blocks.append(
+                QwenImageFunControlBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    has_before_proj=(i == 0),
+                    dtype=dtype,
+                    device=device,
+                    operations=operations,
+                )
+            )
+
+    def _process_hint_tokens(self, hint):
+        if hint is None:
+            return None
+        if hint.ndim == 4:
+            hint = hint.unsqueeze(2)
+
+        # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
+        # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
+        # Default behavior (no inpaint input in stock Apply ControlNet) should use
+        # zeros for mask/inpaint branches, matching VideoX fallback semantics.
+        expected_c = self.control_img_in.weight.shape[1] // 4
+        if hint.shape[1] == 16 and expected_c == 33:
+            zeros_mask = torch.zeros_like(hint[:, :1])
+            zeros_inpaint = torch.zeros_like(hint)
+            hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
+
+        bs, c, t, h, w = hint.shape
+        hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(
+            orig_shape[0],
+            orig_shape[1],
+            orig_shape[-3],
+            orig_shape[-2] // 2,
+            2,
+            orig_shape[-1] // 2,
+            2,
+        )
+        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+        hidden_states = hidden_states.reshape(
+            bs,
+            t * ((h + 1) // 2) * ((w + 1) // 2),
+            c * 4,
+        )
+
+        expected_in = self.control_img_in.weight.shape[1]
+        cur_in = hidden_states.shape[-1]
+        if cur_in < expected_in:
+            pad = torch.zeros(
+                (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            hidden_states = torch.cat([hidden_states, pad], dim=-1)
+        elif cur_in > expected_in:
+            hidden_states = hidden_states[:, :, :expected_in]
+
+        return hidden_states
+
+    def forward(
+        self,
+        x,
+        timesteps,
+        context,
+        attention_mask=None,
+        guidance: torch.Tensor = None,
+        hint=None,
+        transformer_options={},
+        base_model=None,
+        **kwargs,
+    ):
+        if base_model is None:
+            raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
+
+        encoder_hidden_states_mask = attention_mask
+        # Keep attention mask disabled inside Fun control blocks to mirror
+        # VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
+        encoder_hidden_states_mask = None
+
+        hidden_states, img_ids, _ = base_model.process_img(x)
+        hint_tokens = self._process_hint_tokens(hint)
+        if hint_tokens is None:
+            raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
+
+        if hint_tokens.shape[1] != hidden_states.shape[1]:
+            max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
+            hint_tokens = hint_tokens[:, :max_tokens]
+            hidden_states = hidden_states[:, :max_tokens]
+            img_ids = img_ids[:, :max_tokens]
+
+        txt_start = round(
+            max(
+                ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+                ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+            )
+        )
+        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
+
+        hidden_states = base_model.img_in(hidden_states)
+        encoder_hidden_states = base_model.txt_norm(context)
+        encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
+
+        if guidance is not None:
+            guidance = guidance * 1000
+
+        temb = (
+            base_model.time_text_embed(timesteps, hidden_states)
+            if guidance is None
+            else base_model.time_text_embed(timesteps, guidance, hidden_states)
+        )
+
+        c = self.control_img_in(hint_tokens)
+
+        for i, block in enumerate(self.control_blocks):
+            if i == 0:
+                c_in = block.before_proj(c) + hidden_states
+                all_c = []
+            else:
+                all_c = list(torch.unbind(c, dim=0))
+                c_in = all_c.pop(-1)
+
+            encoder_hidden_states, c_out = block(
+                hidden_states=c_in,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                transformer_options=transformer_options,
+            )

+            c_skip = block.after_proj(c_out) * self.hint_scale
+            all_c += [c_skip, c_out]
+            c = torch.stack(all_c, dim=0)
+
+        hints = torch.unbind(c, dim=0)[:-1]
+
+        controlnet_block_samples = [None] * self.main_model_double
+        for local_idx, base_idx in enumerate(self.injection_layers):
+            if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
+                controlnet_block_samples[base_idx] = hints[local_idx]
+
+        return {"input": controlnet_block_samples}
+
+
 class QwenImageControlNetModel(QwenImageTransformer2DModel):

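`_process_hint_tokens` uses the same 2x2 pixel-unshuffle packing as Qwen-Image's latent path: each 2x2 spatial patch folds into the channel dim, so 33 channels become 132 features per token. A standalone sketch of that reshape with illustrative shapes:

```python
import torch

bs, c, t, h, w = 1, 33, 1, 4, 6
hint = torch.randn(bs, c, t, h, w)

x = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))  # make H and W even
x = x.view(bs, c, t, x.shape[-2] // 2, 2, x.shape[-1] // 2, 2)
x = x.permute(0, 2, 3, 5, 1, 4, 6)                       # (bs, t, h/2, w/2, c, 2, 2)
tokens = x.reshape(bs, t * ((h + 1) // 2) * ((w + 1) // 2), c * 4)
print(tokens.shape)  # torch.Size([1, 6, 132]): one token per 2x2 patch
```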
@@ -5,7 +5,7 @@ import comfy.utils

 def convert_lora_bfl_control(sd): #BFL loras for Flux
     sd_out = {}
     for k in sd:
-        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
         sd_out[k_to] = sd[k]

     sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])

@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
             count += 1
     return count

+def any_suffix_in(keys, prefix, main, suffix_list=[]):
+    for x in suffix_list:
+        if "{}{}{}".format(prefix, main, x) in keys:
+            return True
+    return False
+
 def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
     context_dim = None
     use_linear_in_transformer = False

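`any_suffix_in` exists because RMSNorm parameters can now be stored under either suffix: `.scale` from the old file-local class or `.weight` from `operations.RMSNorm`. Usage sketch with hypothetical keys:

```python
def any_suffix_in(keys, prefix, main, suffix_list=[]):
    for x in suffix_list:
        if "{}{}{}".format(prefix, main, x) in keys:
            return True
    return False

keys = {"model.txt_norm.weight": None}
print(any_suffix_in(keys, "model.", "txt_norm.", ["weight", "scale"]))  # True
print(any_suffix_in(keys, "model.", "txt_norm.", ["bias"]))             # False
```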
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["meanflow_sum"] = False
         return dit_config

-    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
+    if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
         dit_config = {}
         if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
             dit_config["image_model"] = "flux2"

@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
-        if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+
+        if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
             dit_config["image_model"] = "chroma"
             dit_config["in_channels"] = 64
             dit_config["out_channels"] = 64

@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["out_dim"] = 3072
             dit_config["hidden_dim"] = 5120
             dit_config["n_layers"] = 5
-            if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
+
+            if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
                 dit_config["image_model"] = "chroma_radiance"
                 dit_config["in_channels"] = 3
                 dit_config["out_channels"] = 3

@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
                 dit_config["nerf_depth"] = 4
                 dit_config["nerf_max_freqs"] = 8
                 dit_config["nerf_tile_size"] = 512
-                dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+                dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
                 dit_config["nerf_embedder_dtype"] = torch.float32
                 if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
                     dit_config["use_x0"] = True

@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         else:
             dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
             dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
-            dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+            dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
             if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
                 dit_config["txt_ids_dims"] = [1, 2]

@@ -679,18 +679,19 @@ class ModelPatcher:
         for key in list(self.pinned):
             self.unpin_weight(key)

-    def _load_list(self, prio_comfy_cast_weights=False):
+    def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
         loading = []
         for n, m in self.model.named_modules():
-            params = []
-            skip = False
-            for name, param in m.named_parameters(recurse=False):
-                params.append(name)
+            default = False
+            params = { name: param for name, param in m.named_parameters(recurse=False) }
             for name, param in m.named_parameters(recurse=True):
                 if name not in params:
-                    skip = True # skip random weights in non leaf modules
+                    default = True # default random weights in non leaf modules
                     break
-            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
+            if default and default_device is not None:
+                for param in params.values():
+                    param.data = param.data.to(device=default_device)
+            if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                 module_mem = comfy.model_management.module_size(m)
                 module_offload_mem = module_mem
                 if hasattr(m, "comfy_cast_weights"):

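The rewrite keeps `_load_list`'s leaf-module test but renames `skip` to `default`: parameters that show up under `named_parameters(recurse=True)` yet are missing from the module's own `recurse=False` set belong to children, marking a non-leaf whose directly-owned stray weights can now be moved to `default_device`. A standalone sketch of the check:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))

for n, m in model.named_modules():
    own = {name for name, _ in m.named_parameters(recurse=False)}
    non_leaf = any(name not in own for name, _ in m.named_parameters(recurse=True))
    print(f"{n or '<root>'}: leaf={not non_leaf}")
# <root>: leaf=False   (its parameters live in the child Linear)
# 0: leaf=True
```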
@@ -1495,7 +1496,7 @@ class ModelPatcherDynamic(ModelPatcher):
         #with pin and unpin syncrhonization which can be expensive for small weights
         #with a high layer rate (e.g. autoregressive LLMs).
         #prioritize the non-comfy weights (note the order reverse).
-        loading = self._load_list(prio_comfy_cast_weights=True)
+        loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
         loading.sort(reverse=True)

         for x in loading:

@@ -1560,6 +1561,8 @@ class ModelPatcherDynamic(ModelPatcher):
                 allocated_size += weight_size
                 vbar.set_watermark_limit(allocated_size)

+            move_weight_functions(m, device_to)
+
         logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")

         self.model.device = device_to

@@ -1579,7 +1582,7 @@ class ModelPatcherDynamic(ModelPatcher):
         return 0 if vbar is None else vbar.free_memory(memory_to_free)

     def partially_unload_ram(self, ram_to_unload):
-        loading = self._load_list(prio_comfy_cast_weights=True)
+        loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
         for x in loading:
             _, _, _, _, m, _ = x
             ram_to_unload -= comfy.pinned_memory.unpin_memory(m)

@@ -1600,6 +1603,8 @@ class ModelPatcherDynamic(ModelPatcher):
         if unpatch_weights:
             self.partially_unload_ram(1e32)
             self.partially_unload(None, 1e32)
+            for m in self.model.modules():
+                move_weight_functions(m, device_to)

     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above

@@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):

     def process_tokens(self, tokens, device):
         end_token = self.special_tokens.get("end", None)
+        pad_token = self.special_tokens.get("pad", -1)
         if end_token is None:
-            cmp_token = self.special_tokens.get("pad", -1)
+            cmp_token = pad_token
         else:
             cmp_token = end_token

@@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             other_embeds = []
             eos = False
             index = 0
+            left_pad = False
             for y in x:
                 if isinstance(y, numbers.Integral):
-                    if eos:
+                    token = int(y)
+                    if index == 0 and token == pad_token:
+                        left_pad = True
+
+                    if eos or (left_pad and token == pad_token):
                         attention_mask.append(0)
                     else:
                         attention_mask.append(1)
-                    token = int(y)
+                        left_pad = False
+
                     tokens_temp += [token]
-                    if not eos and token == cmp_token:
+                    if not eos and token == cmp_token and not left_pad:
                         if end_token is None:
                             attention_mask[-1] = 0
                         eos = True

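With left-padded tokenizers the pad run comes before the prompt, so the mask loop now zeroes a leading stretch of pad tokens and must not treat a leading pad as the EOS/pad comparison token. A simplified standalone trace of the new logic (hypothetical token ids, pad=0, end=2):

```python
pad_token, end_token = 0, 2
tokens = [0, 0, 7, 8, 9, 2]     # left-padded prompt ending in EOS

attention_mask, eos, left_pad = [], False, False
for index, token in enumerate(tokens):
    if index == 0 and token == pad_token:
        left_pad = True
    if eos or (left_pad and token == pad_token):
        attention_mask.append(0)   # leading pads and post-EOS tokens are masked
    else:
        attention_mask.append(1)
        left_pad = False
    if not eos and token == end_token and not left_pad:
        eos = True
print(attention_mask)  # [0, 0, 1, 1, 1, 1]
```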
@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):

     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith("_norm.scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
+
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]

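This hook and the near-identical `process_unet_state_dict` methods added below for Hunyuan3Dv2 and Chroma all do the same thing: rename the old local-RMSNorm `.scale` parameter to the `.weight` name `operations.RMSNorm` expects, so checkpoints saved under either convention still load. A minimal sketch of the renaming:

```python
def process_unet_state_dict(state_dict):
    out_sd = {}
    for k in list(state_dict.keys()):
        key_out = k
        if key_out.endswith(".scale"):   # old RMSNorm parameter name
            key_out = "{}.weight".format(key_out[:-len(".scale")])
        out_sd[key_out] = state_dict[k]
    return out_sd

print(process_unet_state_dict({"txt_norm.scale": 1, "img_in.weight": 2}))
# {'txt_norm.weight': 1, 'img_in.weight': 2}
```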
@@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
             key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
             key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
             key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
-            key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
-            key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
+            key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
+            key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
             key_out = key_out.replace("_attn_proj.", "_attn.proj.")
             key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
             key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
             out_sd[key_out] = state_dict[k]
         return out_sd

@@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):

     latent_format = latent_formats.Hunyuan3Dv2

+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
+
     def process_unet_state_dict_for_saving(self, state_dict):
         replace_prefix = {"": "model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)

@@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):

     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
+
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Chroma(self, device=device)

@@ -10,7 +10,6 @@ import comfy.utils

 def sample_manual_loop_no_classes(
     model,
     ids=None,
-    paddings=[],
     execution_dtype=None,
     cfg_scale: float = 2.0,
     temperature: float = 0.85,

@@ -36,9 +35,6 @@ def sample_manual_loop_no_classes(

     embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
     embeds_batch = embeds.shape[0]
-    for i, t in enumerate(paddings):
-        attention_mask[i, :t] = 0
-        attention_mask[i, t:] = 1

     output_audio_codes = []
     past_key_values = []

@@ -135,13 +131,11 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
         pos_pad = (len(negative) - len(positive))
         positive = [model.special_tokens["pad"]] * pos_pad + positive

-        paddings = [pos_pad, neg_pad]
         ids = [positive, negative]
     else:
-        paddings = []
         ids = [positive]

-    return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+    return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)


 class ACE15Tokenizer(sd1_clip.SD1Tokenizer):

@@ -355,13 +355,6 @@ class RMSNorm(nn.Module):


-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
 def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
     if not isinstance(theta, list):
         theta = [theta]

@@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
     else:
         cos = cos.unsqueeze(1)
         sin = sin.unsqueeze(1)
-    out.append((cos, sin))
+    sin_split = sin.shape[-1] // 2
+    out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))

     if len(out) == 1:
         return out[0]

     return out


 def apply_rope(xq, xk, freqs_cis):
     org_dtype = xq.dtype
     cos = freqs_cis[0]
     sin = freqs_cis[1]
-    q_embed = (xq * cos) + (rotate_half(xq) * sin)
-    k_embed = (xk * cos) + (rotate_half(xk) * sin)
+    nsin = freqs_cis[2]
+
+    q_embed = (xq * cos)
+    q_split = q_embed.shape[-1] // 2
+    q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
+    q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
+
+    k_embed = (xk * cos)
+    k_split = k_embed.shape[-1] // 2
+    k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
+    k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
+
     return q_embed.to(org_dtype), k_embed.to(org_dtype)

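Storing `(cos, sin_half, -sin_half)` at precompute time lets `apply_rope` drop `rotate_half` and work half-by-half with in-place `addcmul_`, avoiding the rotated temporary. The trick relies on RoPE's usual table layout, where the sin/cos tables duplicate their two halves. A standalone sketch checking that the two formulations agree, using the deleted `rotate_half` as the reference:

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

d = 8
xq = torch.randn(1, 1, 4, d)
freqs = torch.randn(1, 1, 4, d // 2)
emb = torch.cat((freqs, freqs), dim=-1)    # RoPE duplicates the freq halves
cos, sin = emb.cos(), emb.sin()

ref = (xq * cos) + (rotate_half(xq) * sin)           # old formulation

half = d // 2
sin_h, nsin_h = sin[..., :half], -sin[..., half:]    # what precompute now stores
q = xq * cos
q[..., :half].addcmul_(xq[..., half:], nsin_h)       # first half:  x1*cos - x2*sin
q[..., half:].addcmul_(xq[..., :half], sin_h)        # second half: x2*cos + x1*sin
print(torch.allclose(ref, q, atol=1e-6))             # True
```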
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
 class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}

@@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module):
         token_weight_pairs = token_weight_pairs["gemma3_12b"]

         out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
+        out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
         out_device = out.device
         if comfy.model_management.should_use_bf16(self.execution_device):
             out = out.to(device=self.execution_device, dtype=torch.bfloat16)

@@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module):

         token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
         num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+        num_tokens = max(num_tokens, 64)
         return num_tokens * constant * 1024 * 1024

 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):

@@ -675,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         "ff_context.linear_in.bias": "txt_mlp.0.bias",
         "ff_context.linear_out.weight": "txt_mlp.2.weight",
         "ff_context.linear_out.bias": "txt_mlp.2.bias",
-        "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
-        "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
-        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
-        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
+        "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
+        "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
+        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
+        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
     }

     for k in block_map:

@@ -701,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         "norm.linear.bias": "modulation.lin.bias",
         "proj_out.weight": "linear2.weight",
         "proj_out.bias": "linear2.bias",
-        "attn.norm_q.weight": "norm.query_norm.scale",
-        "attn.norm_k.weight": "norm.key_norm.scale",
+        "attn.norm_q.weight": "norm.query_norm.weight",
+        "attn.norm_k.weight": "norm.key_norm.weight",
         "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
         "attn.to_out.weight": "linear2.weight", # Flux 2
     }

@@ -1035,7 +1035,7 @@ class TrainLoraNode(io.ComfyNode):
                 io.Boolean.Input(
                     "offloading",
                     default=False,
-                    tooltip="Depth level for gradient checkpointing.",
+                    tooltip="Offload the Model to RAM. Requires Bypass Mode.",
                 ),
                 io.Combo.Input(
                     "existing_lora",

@@ -1124,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
         mp.set_model_compute_dtype(dtype)

+        if mp.is_dynamic():
+            if not bypass_mode:
+                logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
+            bypass_mode = True
+            offloading = True
+        elif offloading:
+            if not bypass_mode:
+                logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
+
         # Prepare latents and compute counts
         latents, num_images, multi_res = _prepare_latents_and_count(
             latents, dtype, bucket_mode