Added support for NewBieModel

Anlia 2025-12-12 15:20:23 +08:00
parent 5495589db3
commit 4c08fd2150
6 changed files with 1400 additions and 44 deletions

View File

@@ -377,7 +377,6 @@ class NextDiT(nn.Module):
        z_image_modulation=False,
        time_scale=1.0,
        pad_tokens_multiple=None,
        clip_text_dim=None,
        image_model=None,
        device=None,
        dtype=None,
@@ -448,31 +447,6 @@ class NextDiT(nn.Module):
            ),
        )
        self.clip_text_pooled_proj = None
        if clip_text_dim is not None:
            self.clip_text_dim = clip_text_dim
            self.clip_text_pooled_proj = nn.Sequential(
                operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
                operation_settings.get("operations").Linear(
                    clip_text_dim,
                    clip_text_dim,
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )
            self.time_text_embed = nn.Sequential(
                nn.SiLU(),
                operation_settings.get("operations").Linear(
                    min(dim, 1024) + clip_text_dim,
                    min(dim, 1024),
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )
        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
@@ -611,15 +585,6 @@ class NextDiT(nn.Module):
        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute
        if self.clip_text_pooled_proj is not None:
            pooled = kwargs.get("clip_text_pooled", None)
            if pooled is not None:
                pooled = self.clip_text_pooled_proj(pooled)
            else:
                pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
        patches = transformer_options.get("patches", {})
        x_is_tensor = isinstance(x, torch.Tensor)
        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)

View File

@@ -0,0 +1,54 @@
import warnings

import torch
import torch.nn as nn

try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            """
            Initialize the RMSNorm normalization layer.

            Args:
                dim (int): The dimension of the input tensor.
                eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

            Attributes:
                eps (float): A small value added to the denominator for numerical stability.
                weight (nn.Parameter): Learnable scaling parameter.
            """
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            """
            Apply the RMSNorm normalization to the input tensor.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The normalized tensor.
            """
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            """
            Forward pass through the RMSNorm layer.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor after applying RMSNorm.
            """
            output = self._norm(x.float()).type_as(x)
            return output * self.weight
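A minimal usage sketch of the vanilla RMSNorm fallback above (illustrative only; it assumes apex is not installed, so the pure-PyTorch class defined in the except branch is the one in scope):

import torch

norm = RMSNorm(dim=8)                  # fallback class defined above
x = torch.randn(2, 16, 8)              # (batch, tokens, dim)
y = norm(x)

# Each feature vector is rescaled to unit root-mean-square and then multiplied
# by the learnable per-channel weight (ones at init), so shape and dtype are
# preserved and the per-token RMS of y is ~1.0 right after construction.
print(y.shape, y.pow(2).mean(-1).sqrt().mean().item())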

comfy/ldm/newbie/model.py (new file, 1207 lines)

File diff suppressed because it is too large.

View File

@@ -928,6 +928,90 @@ class Flux2(Flux):
            cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
        return out

class NewBieImage(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        import comfy.ldm.newbie.model as nb
        super().__init__(model_config, model_type, device=device, unet_model=nb.NewBieNextDiT)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out["c_crossattn"] = comfy.conds.CONDCrossAttn(cross_attn)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
        cap_feats = kwargs.get("cap_feats", None)
        if cap_feats is not None:
            out["cap_feats"] = comfy.conds.CONDRegular(cap_feats)
        cap_mask = kwargs.get("cap_mask", None)
        if cap_mask is not None:
            out["cap_mask"] = comfy.conds.CONDRegular(cap_mask)
        clip_text_pooled = kwargs.get("clip_text_pooled", None)
        if clip_text_pooled is not None:
            out["clip_text_pooled"] = comfy.conds.CONDRegular(clip_text_pooled)
        clip_img_pooled = kwargs.get("clip_img_pooled", None)
        if clip_img_pooled is not None:
            out["clip_img_pooled"] = comfy.conds.CONDRegular(clip_img_pooled)
        return out

    def extra_conds_shapes(self, **kwargs):
        out = super().extra_conds_shapes(**kwargs)
        cap_feats = kwargs.get("cap_feats", None)
        if cap_feats is not None:
            out["cap_feats"] = list(cap_feats.shape)
        clip_text_pooled = kwargs.get("clip_text_pooled", None)
        if clip_text_pooled is not None:
            out["clip_text_pooled"] = list(clip_text_pooled.shape)
        clip_img_pooled = kwargs.get("clip_img_pooled", None)
        if clip_img_pooled is not None:
            out["clip_img_pooled"] = list(clip_img_pooled.shape)
        return out

    def apply_model(
        self, x, t,
        c_concat=None, c_crossattn=None,
        control=None, transformer_options={}, **kwargs
    ):
        sigma = t
        try:
            model_device = next(self.diffusion_model.parameters()).device
        except StopIteration:
            model_device = x.device

        x_in = x.to(device=model_device)
        sigma_in = sigma.to(device=model_device)

        xc = self.model_sampling.calculate_input(sigma_in, x_in)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat.to(device=model_device)], dim=1)

        dtype = self.get_dtype()
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype
        xc = xc.to(dtype=dtype)

        t_val = (1.0 - sigma_in).to(dtype=torch.float32)

        cap_feats = kwargs.get("cap_feats", kwargs.get("cross_attn", c_crossattn))
        cap_mask = kwargs.get("cap_mask", kwargs.get("attention_mask"))
        clip_text_pooled = kwargs.get("clip_text_pooled")
        clip_img_pooled = kwargs.get("clip_img_pooled")

        if cap_feats is not None:
            cap_feats = cap_feats.to(device=model_device, dtype=dtype)
        if cap_mask is None and cap_feats is not None:
            cap_mask = torch.ones(cap_feats.shape[:2], dtype=torch.bool, device=model_device)
        elif cap_mask is not None:
            cap_mask = cap_mask.to(device=model_device)
            if cap_mask.dtype != torch.bool:
                cap_mask = cap_mask != 0

        model_kwargs = {}
        if clip_text_pooled is not None:
            model_kwargs["clip_text_pooled"] = clip_text_pooled.to(device=model_device, dtype=dtype)
        if clip_img_pooled is not None:
            model_kwargs["clip_img_pooled"] = clip_img_pooled.to(device=model_device, dtype=dtype)

        model_output = self.diffusion_model(xc, t_val, cap_feats, cap_mask, **model_kwargs).float()
        model_output = -model_output

        denoised = self.model_sampling.calculate_denoised(sigma_in, model_output, x_in)
        if denoised.device != x.device:
            denoised = denoised.to(device=x.device)
        return denoised
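A hedged sketch of the flow arithmetic apply_model above relies on, assuming comfy's usual flow-matching sampling (denoised = x - sigma * v with velocity v = noise - x0); the NewBie DiT is assumed to run on t = 1 - sigma and to return the negated velocity, which is what the (1.0 - sigma_in) timestep and the model_output = -model_output flip suggest:

import torch

x0 = torch.randn(1, 16, 32, 32)             # clean latent
noise = torch.randn_like(x0)
sigma = torch.tensor(0.7)

x_t = (1.0 - sigma) * x0 + sigma * noise    # flow-matching interpolation
v = noise - x0                              # velocity the sampler expects
model_output = -v                           # assumed sign convention of the net
denoised = x_t - sigma * (-model_output)    # apply_model negates before calculate_denoised
assert torch.allclose(denoised, x0, atol=1e-5)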
class GenmoMochi(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -1110,10 +1194,6 @@ class Lumina2(BaseModel):
        if 'num_tokens' not in out:
            out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
        clip_text_pooled = kwargs["pooled_output"]  # Newbie
        if clip_text_pooled is not None:
            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
        return out

class WAN21(BaseModel):

View File

@@ -6,6 +6,26 @@ import math
import logging
import torch
def is_newbie_unet_state_dict(state_dict, key_prefix):
    state_dict_keys = state_dict.keys()
    try:
        x_embed = state_dict[f"{key_prefix}x_embedder.weight"]
        final = state_dict[f"{key_prefix}final_layer.linear.weight"]
    except KeyError:
        return False
    if x_embed.ndim != 2:
        return False
    dim = x_embed.shape[0]
    patch_dim = x_embed.shape[1]
    if dim != 2304 or patch_dim != 64:
        return False
    if final.shape[0] != patch_dim or final.shape[1] != dim:
        return False
    n_layers = count_blocks(state_dict_keys, f"{key_prefix}layers." + "{}.")
    if n_layers != 36:
        return False
    return True
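The shape signature the detector above keys on follows from the patchifier layout; a small illustrative sketch (patch size and latent channel count are assumptions taken from the Lumina 2 branch below and the Flux latent format used by the supported-models entry, not read from the suppressed NewBie model file):

patch_size = 2                              # assumed, as in the Lumina 2 branch below
in_channels = 16                            # assumed, Flux-style 16-channel latent
dim = 2304

patch_dim = patch_size ** 2 * in_channels   # 64: second axis of x_embedder.weight
assert (dim, patch_dim) == (2304, 64)
# x_embedder.weight is expected as (dim, patch_dim) = (2304, 64),
# final_layer.linear.weight maps dim back to patch_dim, i.e. (64, 2304),
# and a NewBie checkpoint must expose exactly 36 "layers.N." blocks.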
def count_blocks(state_dict_keys, prefix_string):
    count = 0
    while True:
@@ -411,7 +431,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
        return dit_config

    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 / NewBie image
        dit_config = {}
        dit_config["image_model"] = "lumina2"
        dit_config["patch_size"] = 2
@@ -422,6 +442,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
        dit_config["qk_norm"] = True

        if dit_config["dim"] == 2304 and is_newbie_unet_state_dict(state_dict, key_prefix): # NewBie image
            dit_config["n_heads"] = 24
            dit_config["n_kv_heads"] = 8
            dit_config["axes_dims"] = [32, 32, 32]
            dit_config["axes_lens"] = [1024, 512, 512]
            dit_config["rope_theta"] = 10000.0
            dit_config["model_type"] = "newbie_dit"
            dit_config["image_model"] = "NewBieImage"
            return dit_config

        if dit_config["dim"] == 2304: # Original Lumina 2
            dit_config["n_heads"] = 24
            dit_config["n_kv_heads"] = 8
@@ -429,9 +459,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
            dit_config["axes_lens"] = [300, 512, 512]
            dit_config["rope_theta"] = 10000.0
            dit_config["ffn_dim_multiplier"] = 4.0
            ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
            if ctd_weight is not None:
                dit_config["clip_text_dim"] = ctd_weight.shape[0]
        elif dit_config["dim"] == 3840: # Z image
            dit_config["n_heads"] = 30
            dit_config["n_kv_heads"] = 30

View File

@@ -1035,6 +1035,29 @@ class ZImage(Lumina2):
        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))

class NewBieImageModel(supported_models_base.BASE):
    unet_config = {
        "image_model": "NewBieImage",
        "model_type": "newbie_dit",
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 6.0,
    }

    memory_usage_factor = 1.5

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.NewBieImage(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        return None
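The shift of 6.0 in sampling_settings biases sampling toward the high-noise end of the schedule; a hedged sketch assuming comfy's usual flow time shift, sigma = shift * t / (1 + (shift - 1) * t):

def shifted_sigma(t, shift=6.0):
    # assumed flow-matching time shift; illustrative only
    return shift * t / (1 + (shift - 1) * t)

for t in (0.25, 0.5, 0.75):
    print(t, round(shifted_sigma(t), 3))    # 0.25 -> 0.667, 0.5 -> 0.857, 0.75 -> 0.947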
class WAN21_T2V(supported_models_base.BASE):
    unet_config = {
        "image_model": "wan2.1",
@@ -1529,6 +1552,6 @@ class Kandinsky5Image(Kandinsky5):
        return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, NewBieImageModel, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]
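Finally, an illustrative-only sketch (not comfy's actual matcher) of why the detected dit_config above selects NewBieImageModel from the models list: every key a candidate declares in its unet_config has to agree with the detected config, and only the new entry declares image_model "NewBieImage" / model_type "newbie_dit".

detected = {"image_model": "NewBieImage", "model_type": "newbie_dit", "dim": 2304}

def matches(candidate_unet_config, detected_config):
    # illustrative matcher: every declared key must be present and equal
    return all(detected_config.get(k) == v for k, v in candidate_unet_config.items())

print(matches({"image_model": "NewBieImage", "model_type": "newbie_dit"}, detected))  # True
print(matches({"image_model": "lumina2"}, detected))                                  # False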