Compare commits

...

5 Commits

Author SHA1 Message Date
E-Anlia
5599f3a9da
Merge 29655ed6fa into 6592bffc60 2025-12-14 13:04:40 +08:00
chaObserv
6592bffc60
seeds_2: add phi_2 variant and sampler node (#11309)
* Add phi_2 solver type to seeds_2

* Add sampler node for seeds_2
2025-12-14 00:03:29 -05:00
Anlia
29655ed6fa Reuse NextDiT backbone; unify integration and improve interface/perf 2025-12-14 02:47:41 +08:00
E-Anlia
5fcd6c5f79
Add files via upload 2025-12-12 16:01:57 +08:00
Anlia
4c08fd2150 Added support for NewBieModel 2025-12-12 15:20:23 +08:00
8 changed files with 1959 additions and 1580 deletions

View File

@@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
 @torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
     """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
     arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
+    if solver_type not in {"phi_1", "phi_2"}:
+        raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
             denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
             # Step 2
-            denoised_d = torch.lerp(denoised, denoised_2, fac)
-            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+            if solver_type == "phi_1":
+                denoised_d = torch.lerp(denoised, denoised_2, fac)
+                x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+            elif solver_type == "phi_2":
+                b2 = ei_h_phi_2(-h_eta) / r
+                b1 = ei_h_phi_1(-h_eta) - b2
+                x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
             if inject_noise:
                 segment_factor = (r - 1) * h * eta
                 sde_noise = sde_noise * segment_factor.exp()
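For reference, the phi_2 weights in this hunk follow the standard second-order exponential-integrator conditions. The sketch below uses textbook definitions of the φ functions; the helpers ei_h_phi_1 / ei_h_phi_2 referenced in the diff are assumed to compute these (possibly with extra numerical hardening near zero), so treat this as an illustration rather than the repository's implementation.

```python
import torch

def phi_1(t: torch.Tensor) -> torch.Tensor:
    # phi_1(t) = (e^t - 1) / t; expm1 keeps precision for small t
    return torch.expm1(t) / t

def phi_2(t: torch.Tensor) -> torch.Tensor:
    # phi_2(t) = (e^t - 1 - t) / t^2
    return (torch.expm1(t) - t) / (t * t)

# With the intermediate stage at relative position r, the weights in the diff,
#   b2 = phi_2(-h_eta) / r
#   b1 = phi_1(-h_eta) - b2,
# satisfy the second-order conditions b1 + b2 = phi_1 and b2 * r = phi_2.
h_eta, r = torch.tensor(0.3), 0.5
b2 = phi_2(-h_eta) / r
b1 = phi_1(-h_eta) - b2
print(torch.isclose(b1 + b2, phi_1(-h_eta)), torch.isclose(b2 * r, phi_2(-h_eta)))
```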

View File

@@ -377,7 +377,6 @@ class NextDiT(nn.Module):
         z_image_modulation=False,
         time_scale=1.0,
         pad_tokens_multiple=None,
-        clip_text_dim=None,
         image_model=None,
         device=None,
         dtype=None,
@@ -448,31 +447,6 @@ class NextDiT(nn.Module):
             ),
         )
-        self.clip_text_pooled_proj = None
-        if clip_text_dim is not None:
-            self.clip_text_dim = clip_text_dim
-            self.clip_text_pooled_proj = nn.Sequential(
-                operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
-                operation_settings.get("operations").Linear(
-                    clip_text_dim,
-                    clip_text_dim,
-                    bias=True,
-                    device=operation_settings.get("device"),
-                    dtype=operation_settings.get("dtype"),
-                ),
-            )
-            self.time_text_embed = nn.Sequential(
-                nn.SiLU(),
-                operation_settings.get("operations").Linear(
-                    min(dim, 1024) + clip_text_dim,
-                    min(dim, 1024),
-                    bias=True,
-                    device=operation_settings.get("device"),
-                    dtype=operation_settings.get("dtype"),
-                ),
-            )
         self.layers = nn.ModuleList(
             [
                 JointTransformerBlock(
@@ -620,15 +594,6 @@ class NextDiT(nn.Module):
         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute
-        if self.clip_text_pooled_proj is not None:
-            pooled = kwargs.get("clip_text_pooled", None)
-            if pooled is not None:
-                pooled = self.clip_text_pooled_proj(pooled)
-            else:
-                pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
-            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
         patches = transformer_options.get("patches", {})
         x_is_tensor = isinstance(x, torch.Tensor)
         img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)

View File

@@ -0,0 +1,54 @@
import warnings

import torch
import torch.nn as nn

try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            """
            Initialize the RMSNorm normalization layer.

            Args:
                dim (int): The dimension of the input tensor.
                eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

            Attributes:
                eps (float): A small value added to the denominator for numerical stability.
                weight (nn.Parameter): Learnable scaling parameter.
            """
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            """
            Apply the RMSNorm normalization to the input tensor.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The normalized tensor.
            """
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            """
            Forward pass through the RMSNorm layer.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor after applying RMSNorm.
            """
            output = self._norm(x.float()).type_as(x)
            return output * self.weight
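As a quick sanity check of the vanilla fallback above (shapes are hypothetical, picked only to show the last-dimension reduction):

```python
import torch

norm = RMSNorm(1024)               # vanilla fallback; weight starts at ones
x = torch.randn(2, 77, 1024)       # (batch, tokens, features)
y = norm(x)

# With the default unit weight, each token comes out with roughly unit RMS.
print(y.shape)                                  # torch.Size([2, 77, 1024])
print(y.pow(2).mean(-1).sqrt().mean().item())   # ~1.0
```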

comfy/ldm/newbie/model.py (new file, +195 lines)
View File

@@ -0,0 +1,195 @@
from __future__ import annotations

from typing import Optional, Any, Dict

import torch
import torch.nn as nn

import comfy.ldm.common_dit as common_dit
from comfy.ldm.lumina.model import NextDiT as NextDiTBase

from .components import RMSNorm

#######################################################
#           Adds support for NewBie image             #
#######################################################


def _fallback_operations():
    try:
        import comfy.ops
        return comfy.ops.disable_weight_init
    except Exception:
        return None


def _pop_unexpected_kwargs(kwargs: Dict[str, Any]) -> None:
    for k in (
        "model_type",
        "operation_settings",
        "unet_dtype",
        "weight_dtype",
        "precision",
        "extra_model_config",
    ):
        kwargs.pop(k, None)


class NewBieNextDiT_CLIP(NextDiTBase):
    def __init__(
        self,
        *args,
        clip_text_dim: int = 1024,
        clip_img_dim: int = 1024,
        device=None,
        dtype=None,
        operations=None,
        **kwargs,
    ):
        _pop_unexpected_kwargs(kwargs)
        if operations is None:
            operations = _fallback_operations()
        super().__init__(*args, device=device, dtype=dtype, operations=operations, **kwargs)

        self._nb_device = device
        self._nb_dtype = dtype
        self._nb_ops = operations

        min_mod = min(int(getattr(self, "dim", 1024)), 1024)

        if operations is not None and hasattr(operations, "Linear"):
            Linear = operations.Linear
            Norm = getattr(operations, "RMSNorm", None)
        else:
            Linear = nn.Linear
            Norm = None

        if Norm is not None:
            self.clip_text_pooled_proj = nn.Sequential(
                Norm(clip_text_dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype),
                Linear(clip_text_dim, clip_text_dim, bias=True, device=device, dtype=dtype),
            )
        else:
            self.clip_text_pooled_proj = nn.Sequential(
                RMSNorm(clip_text_dim),
                nn.Linear(clip_text_dim, clip_text_dim, bias=True),
            )
        nn.init.normal_(self.clip_text_pooled_proj[1].weight, std=0.01)
        nn.init.zeros_(self.clip_text_pooled_proj[1].bias)

        self.time_text_embed = nn.Sequential(
            nn.SiLU(),
            Linear(min_mod + clip_text_dim, min_mod, bias=True, device=device, dtype=dtype),
        )
        nn.init.zeros_(self.time_text_embed[1].weight)
        nn.init.zeros_(self.time_text_embed[1].bias)

        if Norm is not None:
            self.clip_img_pooled_embedder = nn.Sequential(
                Norm(clip_img_dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype),
                Linear(clip_img_dim, min_mod, bias=True, device=device, dtype=dtype),
            )
        else:
            self.clip_img_pooled_embedder = nn.Sequential(
                RMSNorm(clip_img_dim),
                nn.Linear(clip_img_dim, min_mod, bias=True),
            )
        nn.init.normal_(self.clip_img_pooled_embedder[1].weight, std=0.01)
        nn.init.zeros_(self.clip_img_pooled_embedder[1].bias)

    @staticmethod
    def _get_clip_from_kwargs(transformer_options: dict, kwargs: dict, key: str):
        if key in kwargs:
            return kwargs.get(key)
        if transformer_options is not None and key in transformer_options:
            return transformer_options.get(key)
        extra = transformer_options.get("extra_cond", None) if transformer_options else None
        if isinstance(extra, dict) and key in extra:
            return extra.get(key)
        return None

    def _forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        num_tokens: int,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: dict = {},
        **kwargs,
    ):
        t = timesteps
        cap_feats = context
        cap_mask = attention_mask
        bs, c, h, w = x.shape
        x = common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        t_emb = self.t_embedder(t, dtype=x.dtype)
        adaln_input = t_emb

        clip_text_pooled = self._get_clip_from_kwargs(transformer_options, kwargs, "clip_text_pooled")
        clip_img_pooled = self._get_clip_from_kwargs(transformer_options, kwargs, "clip_img_pooled")

        if clip_text_pooled is not None:
            if clip_text_pooled.dim() > 2:
                clip_text_pooled = clip_text_pooled.view(clip_text_pooled.shape[0], -1)
            clip_text_pooled = clip_text_pooled.to(device=t_emb.device, dtype=t_emb.dtype)
            clip_emb = self.clip_text_pooled_proj(clip_text_pooled)
            adaln_input = self.time_text_embed(torch.cat([t_emb, clip_emb], dim=-1))

        if clip_img_pooled is not None:
            if clip_img_pooled.dim() > 2:
                clip_img_pooled = clip_img_pooled.view(clip_img_pooled.shape[0], -1)
            clip_img_pooled = clip_img_pooled.to(device=t_emb.device, dtype=t_emb.dtype)
            adaln_input = adaln_input + self.clip_img_pooled_embedder(clip_img_pooled)

        if isinstance(cap_feats, torch.Tensor):
            try:
                target_dtype = next(self.cap_embedder.parameters()).dtype
            except StopIteration:
                target_dtype = cap_feats.dtype
            cap_feats = cap_feats.to(device=t_emb.device, dtype=target_dtype)
        cap_feats = self.cap_embedder(cap_feats)

        patches = transformer_options.get("patches", {})
        x_is_tensor = True

        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
            x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options
        )
        freqs_cis = freqs_cis.to(img.device)

        for i, layer in enumerate(self.layers):
            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
            if "double_block" in patches:
                for p in patches["double_block"]:
                    out = p(
                        {
                            "img": img[:, cap_size[0] :],
                            "txt": img[:, : cap_size[0]],
                            "pe": freqs_cis[:, cap_size[0] :],
                            "vec": adaln_input,
                            "x": x,
                            "block_index": i,
                            "transformer_options": transformer_options,
                        }
                    )
                    if isinstance(out, dict):
                        if "img" in out:
                            img[:, cap_size[0] :] = out["img"]
                        if "txt" in out:
                            img[:, : cap_size[0]] = out["txt"]

        img = self.final_layer(img, adaln_input)
        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)
        img = img[:, :, :h, :w]
        return img


def NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP(**kwargs):
    _pop_unexpected_kwargs(kwargs)
    kwargs.setdefault("patch_size", 2)
    kwargs.setdefault("in_channels", 16)
    kwargs.setdefault("dim", 2304)
    kwargs.setdefault("n_layers", 36)
    kwargs.setdefault("n_heads", 24)
    kwargs.setdefault("n_kv_heads", 8)
    kwargs.setdefault("axes_dims", [32, 32, 32])
    kwargs.setdefault("axes_lens", [1024, 512, 512])
    return NewBieNextDiT_CLIP(**kwargs)


def NewBieNextDiT(*, device=None, dtype=None, operations=None, **kwargs):
    _pop_unexpected_kwargs(kwargs)
    if operations is None:
        operations = _fallback_operations()
    if dtype is None:
        dev_str = str(device) if device is not None else ""
        if dev_str.startswith("cuda") and torch.cuda.is_available():
            if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
                dtype = torch.bfloat16
            else:
                dtype = torch.float16
        else:
            dtype = torch.float32
    model = NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP(
        device=device, dtype=dtype, operations=operations, **kwargs
    )
    return model
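A construction sketch for the factory above, assuming the NextDiT base fills its remaining arguments with its own defaults; instantiating the full 3B configuration allocates several gigabytes, so treat this as illustrative rather than something to run casually.

```python
import torch
from comfy.ldm.newbie.model import NewBieNextDiT  # module path as added in this PR

# Off-CUDA, the dtype fallback in NewBieNextDiT resolves to float32.
model = NewBieNextDiT(device="cpu")
print(sum(p.numel() for p in model.parameters()) / 1e9, "B parameters")
```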

View File

@@ -928,6 +928,90 @@ class Flux2(Flux):
             cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

+class NewBieImage(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        import comfy.ldm.newbie.model as nb
+        super().__init__(model_config, model_type, device=device, unet_model=nb.NewBieNextDiT)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out["c_crossattn"] = comfy.conds.CONDCrossAttn(cross_attn)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
+        cap_feats = kwargs.get("cap_feats", None)
+        if cap_feats is not None:
+            out["cap_feats"] = comfy.conds.CONDRegular(cap_feats)
+        cap_mask = kwargs.get("cap_mask", None)
+        if cap_mask is not None:
+            out["cap_mask"] = comfy.conds.CONDRegular(cap_mask)
+        clip_text_pooled = kwargs.get("clip_text_pooled", None)
+        if clip_text_pooled is not None:
+            out["clip_text_pooled"] = comfy.conds.CONDRegular(clip_text_pooled)
+        clip_img_pooled = kwargs.get("clip_img_pooled", None)
+        if clip_img_pooled is not None:
+            out["clip_img_pooled"] = comfy.conds.CONDRegular(clip_img_pooled)
+        return out
+
+    def extra_conds_shapes(self, **kwargs):
+        out = super().extra_conds_shapes(**kwargs)
+        cap_feats = kwargs.get("cap_feats", None)
+        if cap_feats is not None:
+            out["cap_feats"] = list(cap_feats.shape)
+        clip_text_pooled = kwargs.get("clip_text_pooled", None)
+        if clip_text_pooled is not None:
+            out["clip_text_pooled"] = list(clip_text_pooled.shape)
+        clip_img_pooled = kwargs.get("clip_img_pooled", None)
+        if clip_img_pooled is not None:
+            out["clip_img_pooled"] = list(clip_img_pooled.shape)
+        return out
+
+    def apply_model(
+        self, x, t,
+        c_concat=None, c_crossattn=None,
+        control=None, transformer_options={}, **kwargs
+    ):
+        sigma = t
+        try:
+            model_device = next(self.diffusion_model.parameters()).device
+        except StopIteration:
+            model_device = x.device
+        x_in = x.to(device=model_device)
+        sigma_in = sigma.to(device=model_device)
+        xc = self.model_sampling.calculate_input(sigma_in, x_in)
+        if c_concat is not None:
+            xc = torch.cat([xc] + [c_concat.to(device=model_device)], dim=1)
+        dtype = self.get_dtype()
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
+        xc = xc.to(dtype=dtype)
+        t_val = (1.0 - sigma_in).to(dtype=torch.float32)
+        cap_feats = kwargs.get("cap_feats", kwargs.get("cross_attn", c_crossattn))
+        cap_mask = kwargs.get("cap_mask", kwargs.get("attention_mask"))
+        clip_text_pooled = kwargs.get("clip_text_pooled")
+        clip_img_pooled = kwargs.get("clip_img_pooled")
+        if cap_feats is not None:
+            cap_feats = cap_feats.to(device=model_device, dtype=dtype)
+        if cap_mask is None and cap_feats is not None:
+            cap_mask = torch.ones(cap_feats.shape[:2], dtype=torch.bool, device=model_device)
+        elif cap_mask is not None:
+            cap_mask = cap_mask.to(device=model_device)
+            if cap_mask.dtype != torch.bool:
+                cap_mask = cap_mask != 0
+        model_kwargs = {}
+        if clip_text_pooled is not None:
+            model_kwargs["clip_text_pooled"] = clip_text_pooled.to(device=model_device, dtype=dtype)
+        if clip_img_pooled is not None:
+            model_kwargs["clip_img_pooled"] = clip_img_pooled.to(device=model_device, dtype=dtype)
+        model_output = self.diffusion_model(xc, t_val, cap_feats, cap_mask, **model_kwargs).float()
+        model_output = -model_output
+        denoised = self.model_sampling.calculate_denoised(sigma_in, model_output, x_in)
+        if denoised.device != x.device:
+            denoised = denoised.to(device=x.device)
+        return denoised
+
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -1110,10 +1194,6 @@ class Lumina2(BaseModel):
         if 'num_tokens' not in out:
             out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
-        clip_text_pooled = kwargs["pooled_output"]  # Newbie
-        if clip_text_pooled is not None:
-            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
         return out

 class WAN21(BaseModel):
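NewBieImage.apply_model conditions the DiT on t = 1 - sigma and negates the raw network output before the denoise step. Below is a minimal numeric illustration, assuming ComfyUI's flow-matching (CONST) parameterization x = (1 - sigma) * x0 + sigma * noise, a denoise step of x - sigma * model_output, and a network that predicts the negated velocity x0 - noise; the numbers are made up.

```python
import torch

sigma = torch.tensor([0.8])
t_val = 1.0 - sigma                     # what the DiT is conditioned on
x0, noise = torch.randn(4), torch.randn(4)
x = (1.0 - sigma) * x0 + sigma * noise  # noisy latent at this sigma

raw = x0 - noise                        # assumed network prediction (negated velocity)
model_output = -raw                     # the sign flip done in apply_model
denoised = x - sigma * model_output     # assumed CONST denoise: x - sigma * v
print(torch.allclose(denoised, x0, atol=1e-5))  # True: recovers x0
```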

View File

@@ -6,6 +6,26 @@ import math
 import logging
 import torch

+def is_newbie_unet_state_dict(state_dict, key_prefix):
+    state_dict_keys = state_dict.keys()
+    try:
+        x_embed = state_dict[f"{key_prefix}x_embedder.weight"]
+        final = state_dict[f"{key_prefix}final_layer.linear.weight"]
+    except KeyError:
+        return False
+    if x_embed.ndim != 2:
+        return False
+    dim = x_embed.shape[0]
+    patch_dim = x_embed.shape[1]
+    if dim != 2304 or patch_dim != 64:
+        return False
+    if final.shape[0] != patch_dim or final.shape[1] != dim:
+        return False
+    n_layers = count_blocks(state_dict_keys, f"{key_prefix}layers." + "{}.")
+    if n_layers != 36:
+        return False
+    return True
+
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -411,7 +431,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
         return dit_config

-    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys:  # Lumina 2
+    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys:  # Lumina 2 / NewBie image
         dit_config = {}
         dit_config["image_model"] = "lumina2"
         dit_config["patch_size"] = 2
@@ -422,6 +442,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
         dit_config["qk_norm"] = True

+        if dit_config["dim"] == 2304 and is_newbie_unet_state_dict(state_dict, key_prefix):  # NewBie image
+            dit_config["n_heads"] = 24
+            dit_config["n_kv_heads"] = 8
+            dit_config["axes_dims"] = [32, 32, 32]
+            dit_config["axes_lens"] = [1024, 512, 512]
+            dit_config["rope_theta"] = 10000.0
+            dit_config["model_type"] = "newbie_dit"
+            dit_config["image_model"] = "NewBieImage"
+            return dit_config
+
         if dit_config["dim"] == 2304:  # Original Lumina 2
             dit_config["n_heads"] = 24
             dit_config["n_kv_heads"] = 8
@@ -429,9 +459,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["axes_lens"] = [300, 512, 512]
             dit_config["rope_theta"] = 10000.0
             dit_config["ffn_dim_multiplier"] = 4.0
-            ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
-            if ctd_weight is not None:
-                dit_config["clip_text_dim"] = ctd_weight.shape[0]
         elif dit_config["dim"] == 3840:  # Z image
             dit_config["n_heads"] = 30
             dit_config["n_kv_heads"] = 30

File diff suppressed because it is too large.

View File

@@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
     get_sampler = execute

+class SamplerSEEDS2(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SamplerSEEDS2",
+            category="sampling/custom_sampling/samplers",
+            inputs=[
+                io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
+                io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
+                io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
+                io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
+            ],
+            outputs=[io.Sampler.Output()]
+        )
+
+    @classmethod
+    def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
+        sampler_name = "seeds_2"
+        sampler = comfy.samplers.ksampler(
+            sampler_name,
+            {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
+        )
+        return io.NodeOutput(sampler)
+
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0
@@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
             SamplerDPMAdaptative,
             SamplerER_SDE,
             SamplerSASolver,
+            SamplerSEEDS2,
             SplitSigmas,
             SplitSigmasDenoise,
             FlipSigmas,
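The node itself is a thin wrapper: execute() simply passes the extra options dict to comfy.samplers.ksampler. The same SAMPLER object the node outputs can be built directly in Python, mirroring the diff above rather than documenting a separate public API:

```python
import comfy.samplers

# Same call the SamplerSEEDS2 node makes internally; plug the result into a
# SamplerCustom-style workflow wherever a SAMPLER input is expected.
sampler = comfy.samplers.ksampler(
    "seeds_2",
    {"eta": 1.0, "s_noise": 1.0, "r": 0.5, "solver_type": "phi_2"},
)
```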