Added support for NewBieModel
Commit 4c08fd2150 (parent 5495589db3)
@@ -377,7 +377,6 @@ class NextDiT(nn.Module):
        z_image_modulation=False,
        time_scale=1.0,
        pad_tokens_multiple=None,
        clip_text_dim=None,
        image_model=None,
        device=None,
        dtype=None,

@@ -448,31 +447,6 @@ class NextDiT(nn.Module):
            ),
        )

        self.clip_text_pooled_proj = None

        if clip_text_dim is not None:
            self.clip_text_dim = clip_text_dim
            self.clip_text_pooled_proj = nn.Sequential(
                operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
                operation_settings.get("operations").Linear(
                    clip_text_dim,
                    clip_text_dim,
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )
            self.time_text_embed = nn.Sequential(
                nn.SiLU(),
                operation_settings.get("operations").Linear(
                    min(dim, 1024) + clip_text_dim,
                    min(dim, 1024),
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )

        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(

@@ -611,15 +585,6 @@ class NextDiT(nn.Module):

        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

        if self.clip_text_pooled_proj is not None:
            pooled = kwargs.get("clip_text_pooled", None)
            if pooled is not None:
                pooled = self.clip_text_pooled_proj(pooled)
            else:
                pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)

            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))

        patches = transformer_options.get("patches", {})
        x_is_tensor = isinstance(x, torch.Tensor)
        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
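
The block touched above concatenated the timestep embedding with a projected pooled text embedding to form the adaLN modulation signal; with this commit that path presumably moves into the new comfy/ldm/newbie files added below. A minimal shape sketch of that path, using plain nn layers instead of operation_settings (dim=2304 matches the NewBie config detected later in this commit; clip_text_dim=1280 is purely illustrative, since the real value is read from the checkpoint):

import torch
import torch.nn as nn

dim, clip_text_dim = 2304, 1280          # clip_text_dim is illustrative; the real value comes from the checkpoint
t = torch.randn(1, min(dim, 1024))       # timestep embedding produced earlier in NextDiT
pooled = torch.randn(1, clip_text_dim)   # pooled text embedding passed in as "clip_text_pooled"

time_text_embed = nn.Sequential(         # same structure as above, with plain nn layers for the sketch
    nn.SiLU(),
    nn.Linear(min(dim, 1024) + clip_text_dim, min(dim, 1024), bias=True),
)
adaln_input = time_text_embed(torch.cat((t, pooled), dim=-1))
print(adaln_input.shape)                 # torch.Size([1, 1024]), the adaLN modulation input
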
comfy/ldm/newbie/components.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import warnings

import torch
import torch.nn as nn

try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            """
            Initialize the RMSNorm normalization layer.

            Args:
                dim (int): The dimension of the input tensor.
                eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

            Attributes:
                eps (float): A small value added to the denominator for numerical stability.
                weight (nn.Parameter): Learnable scaling parameter.

            """
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            """
            Apply the RMSNorm normalization to the input tensor.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The normalized tensor.

            """
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            """
            Forward pass through the RMSNorm layer.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor after applying RMSNorm.

            """
            output = self._norm(x.float()).type_as(x)
            return output * self.weight
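
A minimal usage sketch of the fallback RMSNorm defined above (assuming apex is not installed, so the vanilla class is what gets exported; the dim of 2304 simply mirrors the NewBie hidden size this commit detects elsewhere):

import torch
from comfy.ldm.newbie.components import RMSNorm

norm = RMSNorm(2304)             # learnable weight of shape (2304,), eps defaults to 1e-6
x = torch.randn(2, 16, 2304)     # (batch, tokens, dim)
y = norm(x)                      # scale each token to unit RMS over the last dim, then apply the weight
assert y.shape == x.shape
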
comfy/ldm/newbie/model.py (new file, 1207 lines)
File diff suppressed because it is too large
@@ -928,6 +928,90 @@ class Flux2(Flux):
                cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
        return out

class NewBieImage(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        import comfy.ldm.newbie.model as nb
        super().__init__(model_config, model_type, device=device, unet_model=nb.NewBieNextDiT)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out["c_crossattn"] = comfy.conds.CONDCrossAttn(cross_attn)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
        cap_feats = kwargs.get("cap_feats", None)
        if cap_feats is not None:
            out["cap_feats"] = comfy.conds.CONDRegular(cap_feats)
        cap_mask = kwargs.get("cap_mask", None)
        if cap_mask is not None:
            out["cap_mask"] = comfy.conds.CONDRegular(cap_mask)
        clip_text_pooled = kwargs.get("clip_text_pooled", None)
        if clip_text_pooled is not None:
            out["clip_text_pooled"] = comfy.conds.CONDRegular(clip_text_pooled)
        clip_img_pooled = kwargs.get("clip_img_pooled", None)
        if clip_img_pooled is not None:
            out["clip_img_pooled"] = comfy.conds.CONDRegular(clip_img_pooled)
        return out

    def extra_conds_shapes(self, **kwargs):
        out = super().extra_conds_shapes(**kwargs)
        cap_feats = kwargs.get("cap_feats", None)
        if cap_feats is not None:
            out["cap_feats"] = list(cap_feats.shape)
        clip_text_pooled = kwargs.get("clip_text_pooled", None)
        if clip_text_pooled is not None:
            out["clip_text_pooled"] = list(clip_text_pooled.shape)
        clip_img_pooled = kwargs.get("clip_img_pooled", None)
        if clip_img_pooled is not None:
            out["clip_img_pooled"] = list(clip_img_pooled.shape)
        return out

    def apply_model(
        self, x, t,
        c_concat=None, c_crossattn=None,
        control=None, transformer_options={}, **kwargs
    ):
        sigma = t
        try:
            model_device = next(self.diffusion_model.parameters()).device
        except StopIteration:
            model_device = x.device
        x_in = x.to(device=model_device)
        sigma_in = sigma.to(device=model_device)
        xc = self.model_sampling.calculate_input(sigma_in, x_in)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat.to(device=model_device)], dim=1)
        dtype = self.get_dtype()
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype
        xc = xc.to(dtype=dtype)
        t_val = (1.0 - sigma_in).to(dtype=torch.float32)
        cap_feats = kwargs.get("cap_feats", kwargs.get("cross_attn", c_crossattn))
        cap_mask = kwargs.get("cap_mask", kwargs.get("attention_mask"))
        clip_text_pooled = kwargs.get("clip_text_pooled")
        clip_img_pooled = kwargs.get("clip_img_pooled")
        if cap_feats is not None:
            cap_feats = cap_feats.to(device=model_device, dtype=dtype)
        if cap_mask is None and cap_feats is not None:
            cap_mask = torch.ones(cap_feats.shape[:2], dtype=torch.bool, device=model_device)
        elif cap_mask is not None:
            cap_mask = cap_mask.to(device=model_device)
            if cap_mask.dtype != torch.bool:
                cap_mask = cap_mask != 0
        model_kwargs = {}
        if clip_text_pooled is not None:
            model_kwargs["clip_text_pooled"] = clip_text_pooled.to(device=model_device, dtype=dtype)
        if clip_img_pooled is not None:
            model_kwargs["clip_img_pooled"] = clip_img_pooled.to(device=model_device, dtype=dtype)
        model_output = self.diffusion_model(xc, t_val, cap_feats, cap_mask, **model_kwargs).float()
        model_output = -model_output
        denoised = self.model_sampling.calculate_denoised(sigma_in, model_output, x_in)
        if denoised.device != x.device:
            denoised = denoised.to(device=x.device)
        return denoised

class GenmoMochi(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

@@ -1110,10 +1194,6 @@ class Lumina2(BaseModel):
        if 'num_tokens' not in out:
            out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])

        clip_text_pooled = kwargs["pooled_output"]  # Newbie
        if clip_text_pooled is not None:
            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)

        return out

class WAN21(BaseModel):
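
NewBieImage.apply_model above follows the usual Lumina-style flow-matching conventions: the DiT is conditioned on t = 1 - sigma and its raw output is negated before the sampler converts it to a denoised latent. A small numeric sketch of that last step, under the assumption that model_sampling behaves like ComfyUI's CONST/flow rule where calculate_input is the identity and calculate_denoised(sigma, v, x) = x - sigma * v (an assumption, not confirmed by this diff):

import torch

def calculate_denoised(sigma, model_output, model_input):
    # assumed flow rule (CONST-style): predicted clean latent = x - sigma * v
    sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
    return model_input - model_output * sigma

x = torch.randn(1, 16, 64, 64)     # noisy latent at noise level sigma
sigma = torch.tensor([0.7])

t_val = 1.0 - sigma                # the DiT's time axis runs opposite to the sampler's sigma
raw = torch.randn_like(x)          # stand-in for self.diffusion_model(xc, t_val, cap_feats, cap_mask, ...)
model_output = -raw                # sign flip applied in apply_model
denoised = calculate_denoised(sigma, model_output, x)

# net effect: denoised = x + sigma * raw, i.e. the raw prediction points from the noisy latent toward the image
assert torch.allclose(denoised, x + sigma.view(-1, 1, 1, 1) * raw)
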
@@ -6,6 +6,26 @@ import math
import logging
import torch

def is_newbie_unet_state_dict(state_dict, key_prefix):
    state_dict_keys = state_dict.keys()
    try:
        x_embed = state_dict[f"{key_prefix}x_embedder.weight"]
        final = state_dict[f"{key_prefix}final_layer.linear.weight"]
    except KeyError:
        return False
    if x_embed.ndim != 2:
        return False
    dim = x_embed.shape[0]
    patch_dim = x_embed.shape[1]
    if dim != 2304 or patch_dim != 64:
        return False
    if final.shape[0] != patch_dim or final.shape[1] != dim:
        return False
    n_layers = count_blocks(state_dict_keys, f"{key_prefix}layers." + "{}.")
    if n_layers != 36:
        return False
    return True

def count_blocks(state_dict_keys, prefix_string):
    count = 0
    while True:

@@ -411,7 +431,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
        return dit_config

    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys:  # Lumina 2
    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys:  # Lumina 2 / NewBie image
        dit_config = {}
        dit_config["image_model"] = "lumina2"
        dit_config["patch_size"] = 2

@@ -422,6 +442,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
        dit_config["qk_norm"] = True

        if dit_config["dim"] == 2304 and is_newbie_unet_state_dict(state_dict, key_prefix):  # NewBie image
            dit_config["n_heads"] = 24
            dit_config["n_kv_heads"] = 8
            dit_config["axes_dims"] = [32, 32, 32]
            dit_config["axes_lens"] = [1024, 512, 512]
            dit_config["rope_theta"] = 10000.0
            dit_config["model_type"] = "newbie_dit"
            dit_config["image_model"] = "NewBieImage"
            return dit_config

        if dit_config["dim"] == 2304:  # Original Lumina 2
            dit_config["n_heads"] = 24
            dit_config["n_kv_heads"] = 8

@@ -429,9 +459,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
            dit_config["axes_lens"] = [300, 512, 512]
            dit_config["rope_theta"] = 10000.0
            dit_config["ffn_dim_multiplier"] = 4.0
            ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
            if ctd_weight is not None:
                dit_config["clip_text_dim"] = ctd_weight.shape[0]
        elif dit_config["dim"] == 3840:  # Z image
            dit_config["n_heads"] = 30
            dit_config["n_kv_heads"] = 30
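
The helper added at the top of this file is a pure shape check, so it can be exercised with a dummy state dict. A small self-test sketch (the per-layer key names are placeholders; count_blocks only cares that some key starts with "layers.{i}."):

import torch
from comfy.model_detection import is_newbie_unet_state_dict

prefix = "model.diffusion_model."        # typical diffusion-model key prefix; illustrative
sd = {
    f"{prefix}x_embedder.weight": torch.zeros(2304, 64),          # dim=2304, patch_dim=64
    f"{prefix}final_layer.linear.weight": torch.zeros(64, 2304),  # must map dim back to patch_dim
}
for i in range(36):                      # 36 transformer blocks are required
    sd[f"{prefix}layers.{i}.placeholder.weight"] = torch.zeros(1)

print(is_newbie_unet_state_dict(sd, prefix))   # True
print(is_newbie_unet_state_dict({}, prefix))   # False: the embedder keys are missing
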
@@ -1035,6 +1035,29 @@ class ZImage(Lumina2):
        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))

class NewBieImageModel(supported_models_base.BASE):
    unet_config = {
        "image_model": "NewBieImage",
        "model_type": "newbie_dit",
    }
    sampling_settings = {
        "multiplier": 1.0,
        "shift": 6.0,
    }
    memory_usage_factor = 1.5
    unet_extra_config = {}
    latent_format = latent_formats.Flux
    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.NewBieImage(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        return None
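
For orientation, NewBieImageModel plugs into the same selection flow as the other entries in this file: the dit_config returned by the NewBie branch of detect_unet_config() has to agree with unet_config above, after which get_model() builds the model_base.NewBieImage wrapper. A simplified sketch of that matching step (the real comparison is done by the supported-models machinery; this only illustrates the subset check):

# values copied from the NewBie branch of detect_unet_config() in this commit
dit_config = {
    "image_model": "NewBieImage",
    "model_type": "newbie_dit",
    "dim": 2304,
    "n_heads": 24,
    "n_kv_heads": 8,
    "axes_dims": [32, 32, 32],
    "axes_lens": [1024, 512, 512],
    "rope_theta": 10000.0,
}

unet_config = {                    # NewBieImageModel.unet_config from above
    "image_model": "NewBieImage",
    "model_type": "newbie_dit",
}

# simplified matching: every key pinned by the supported-models entry must agree with the detected config
selected = all(dit_config.get(k) == v for k, v in unet_config.items())
print(selected)   # True -> get_model() is called and returns model_base.NewBieImage
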
class WAN21_T2V(supported_models_base.BASE):
    unet_config = {
        "image_model": "wan2.1",

@@ -1529,6 +1552,6 @@ class Kandinsky5Image(Kandinsky5):
        return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))


models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, NewBieImageModel, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]

models += [SVD_img2vid]