mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-15 01:07:03 +08:00

Compare commits aa8a19e2da...5599f3a9da (5 commits):

- 5599f3a9da
- 6592bffc60
- 29655ed6fa
- 5fcd6c5f79
- 4c08fd2150
comfy/k_diffusion/sampling.py
@@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
 @torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
     """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
     arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
+    if solver_type not in {"phi_1", "phi_2"}:
+        raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
+
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None
         denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

         # Step 2
-        denoised_d = torch.lerp(denoised, denoised_2, fac)
-        x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        if solver_type == "phi_1":
+            denoised_d = torch.lerp(denoised, denoised_2, fac)
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        elif solver_type == "phi_2":
+            b2 = ei_h_phi_2(-h_eta) / r
+            b1 = ei_h_phi_1(-h_eta) - b2
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)

         if inject_noise:
             segment_factor = (r - 1) * h * eta
             sde_noise = sde_noise * segment_factor.exp()
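Note: h_eta, alpha_t, ei_h_phi_1, and ei_h_phi_2 are defined outside these hunks. A minimal sketch of why the new phi_2 weights are consistent, assuming the helpers compute e^z - 1 and e^z - 1 - z (that is, z*phi_1(z) and z**2*phi_2(z) in the usual exponential-integrator notation; the real definitions are not shown in this diff):

import torch

def ei_h_phi_1(z):  # assumed definition: e^z - 1, i.e. z * phi_1(z)
    return torch.expm1(z)

def ei_h_phi_2(z):  # assumed definition: e^z - 1 - z, i.e. z**2 * phi_2(z)
    return torch.expm1(z) - z

h_eta = torch.tensor(0.4)  # illustrative step size
r = 0.5                    # intermediate-stage position, the sampler's default

b2 = ei_h_phi_2(-h_eta) / r
b1 = ei_h_phi_1(-h_eta) - b2
# b1 + b2 == ei_h_phi_1(-h_eta) by construction, so whenever the two model
# calls agree (denoised == denoised_2) the phi_2 branch collapses to phi_1.
print(torch.isclose(b1 + b2, ei_h_phi_1(-h_eta)))  # tensor(True)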
comfy/ldm/lumina/model.py
@@ -377,7 +377,6 @@ class NextDiT(nn.Module):
         z_image_modulation=False,
         time_scale=1.0,
         pad_tokens_multiple=None,
-        clip_text_dim=None,
         image_model=None,
         device=None,
         dtype=None,
@@ -448,31 +447,6 @@ class NextDiT(nn.Module):
             ),
         )

-        self.clip_text_pooled_proj = None
-
-        if clip_text_dim is not None:
-            self.clip_text_dim = clip_text_dim
-            self.clip_text_pooled_proj = nn.Sequential(
-                operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
-                operation_settings.get("operations").Linear(
-                    clip_text_dim,
-                    clip_text_dim,
-                    bias=True,
-                    device=operation_settings.get("device"),
-                    dtype=operation_settings.get("dtype"),
-                ),
-            )
-            self.time_text_embed = nn.Sequential(
-                nn.SiLU(),
-                operation_settings.get("operations").Linear(
-                    min(dim, 1024) + clip_text_dim,
-                    min(dim, 1024),
-                    bias=True,
-                    device=operation_settings.get("device"),
-                    dtype=operation_settings.get("dtype"),
-                ),
-            )
-
         self.layers = nn.ModuleList(
             [
                 JointTransformerBlock(
@@ -620,15 +594,6 @@ class NextDiT(nn.Module):

         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D) # todo check if able to batchify w.o. redundant compute

-        if self.clip_text_pooled_proj is not None:
-            pooled = kwargs.get("clip_text_pooled", None)
-            if pooled is not None:
-                pooled = self.clip_text_pooled_proj(pooled)
-            else:
-                pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
-
-            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
-
         patches = transformer_options.get("patches", {})
         x_is_tensor = isinstance(x, torch.Tensor)
         img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
comfy/ldm/newbie/components.py (new file, +54)
@@ -0,0 +1,54 @@
+import warnings
+
+import torch
+import torch.nn as nn
+
+try:
+    from apex.normalization import FusedRMSNorm as RMSNorm
+except ImportError:
+    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
+
+    class RMSNorm(torch.nn.Module):
+        def __init__(self, dim: int, eps: float = 1e-6):
+            """
+            Initialize the RMSNorm normalization layer.
+
+            Args:
+                dim (int): The dimension of the input tensor.
+                eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+            Attributes:
+                eps (float): A small value added to the denominator for numerical stability.
+                weight (nn.Parameter): Learnable scaling parameter.
+
+            """
+            super().__init__()
+            self.eps = eps
+            self.weight = nn.Parameter(torch.ones(dim))
+
+        def _norm(self, x):
+            """
+            Apply the RMSNorm normalization to the input tensor.
+
+            Args:
+                x (torch.Tensor): The input tensor.
+
+            Returns:
+                torch.Tensor: The normalized tensor.
+
+            """
+            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+        def forward(self, x):
+            """
+            Forward pass through the RMSNorm layer.
+
+            Args:
+                x (torch.Tensor): The input tensor.
+
+            Returns:
+                torch.Tensor: The output tensor after applying RMSNorm.
+
+            """
+            output = self._norm(x.float()).type_as(x)
+            return output * self.weight
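Note: a quick numeric check of the vanilla fallback above (when apex is importable, FusedRMSNorm takes its place and should behave equivalently). The weight starts at ones, so a fresh layer returns unit-RMS vectors:

import torch
from comfy.ldm.newbie.components import RMSNorm

torch.manual_seed(0)
norm = RMSNorm(8)            # dim=8, eps defaults in the fallback to 1e-6
x = torch.randn(2, 8) * 5.0  # deliberately badly scaled input

y = norm(x)
print(y.pow(2).mean(-1).sqrt())  # ~tensor([1.0000, 1.0000]): unit RMS per row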
comfy/ldm/newbie/model.py (new file, +195)
@@ -0,0 +1,195 @@
+from __future__ import annotations
+from typing import Optional, Any, Dict
+import torch
+import torch.nn as nn
+import comfy.ldm.common_dit as common_dit
+from comfy.ldm.lumina.model import NextDiT as NextDiTBase
+from .components import RMSNorm
+
+#######################################################
+#            Adds support for NewBie image            #
+#######################################################
+
+def _fallback_operations():
+    try:
+        import comfy.ops
+        return comfy.ops.disable_weight_init
+    except Exception:
+        return None
+
+def _pop_unexpected_kwargs(kwargs: Dict[str, Any]) -> None:
+    for k in (
+        "model_type",
+        "operation_settings",
+        "unet_dtype",
+        "weight_dtype",
+        "precision",
+        "extra_model_config",
+    ):
+        kwargs.pop(k, None)
+
+class NewBieNextDiT_CLIP(NextDiTBase):
+
+    def __init__(
+        self,
+        *args,
+        clip_text_dim: int = 1024,
+        clip_img_dim: int = 1024,
+        device=None,
+        dtype=None,
+        operations=None,
+        **kwargs,
+    ):
+        _pop_unexpected_kwargs(kwargs)
+        if operations is None:
+            operations = _fallback_operations()
+        super().__init__(*args, device=device, dtype=dtype, operations=operations, **kwargs)
+        self._nb_device = device
+        self._nb_dtype = dtype
+        self._nb_ops = operations
+        min_mod = min(int(getattr(self, "dim", 1024)), 1024)
+        if operations is not None and hasattr(operations, "Linear"):
+            Linear = operations.Linear
+            Norm = getattr(operations, "RMSNorm", None)
+        else:
+            Linear = nn.Linear
+            Norm = None
+        if Norm is not None:
+            self.clip_text_pooled_proj = nn.Sequential(
+                Norm(clip_text_dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype),
+                Linear(clip_text_dim, clip_text_dim, bias=True, device=device, dtype=dtype),
+            )
+        else:
+            self.clip_text_pooled_proj = nn.Sequential(
+                RMSNorm(clip_text_dim),
+                nn.Linear(clip_text_dim, clip_text_dim, bias=True),
+            )
+        nn.init.normal_(self.clip_text_pooled_proj[1].weight, std=0.01)
+        nn.init.zeros_(self.clip_text_pooled_proj[1].bias)
+        self.time_text_embed = nn.Sequential(
+            nn.SiLU(),
+            Linear(min_mod + clip_text_dim, min_mod, bias=True, device=device, dtype=dtype),
+        )
+        nn.init.zeros_(self.time_text_embed[1].weight)
+        nn.init.zeros_(self.time_text_embed[1].bias)
+        if Norm is not None:
+            self.clip_img_pooled_embedder = nn.Sequential(
+                Norm(clip_img_dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype),
+                Linear(clip_img_dim, min_mod, bias=True, device=device, dtype=dtype),
+            )
+        else:
+            self.clip_img_pooled_embedder = nn.Sequential(
+                RMSNorm(clip_img_dim),
+                nn.Linear(clip_img_dim, min_mod, bias=True),
+            )
+        nn.init.normal_(self.clip_img_pooled_embedder[1].weight, std=0.01)
+        nn.init.zeros_(self.clip_img_pooled_embedder[1].bias)
+
+    @staticmethod
+    def _get_clip_from_kwargs(transformer_options: dict, kwargs: dict, key: str):
+        if key in kwargs:
+            return kwargs.get(key)
+        if transformer_options is not None and key in transformer_options:
+            return transformer_options.get(key)
+        extra = transformer_options.get("extra_cond", None) if transformer_options else None
+        if isinstance(extra, dict) and key in extra:
+            return extra.get(key)
+        return None
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        context: torch.Tensor,
+        num_tokens: int,
+        attention_mask: Optional[torch.Tensor] = None,
+        transformer_options: dict = {},
+        **kwargs,
+    ):
+        t = timesteps
+        cap_feats = context
+        cap_mask = attention_mask
+        bs, c, h, w = x.shape
+        x = common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+        t_emb = self.t_embedder(t, dtype=x.dtype)
+        adaln_input = t_emb
+        clip_text_pooled = self._get_clip_from_kwargs(transformer_options, kwargs, "clip_text_pooled")
+        clip_img_pooled = self._get_clip_from_kwargs(transformer_options, kwargs, "clip_img_pooled")
+        if clip_text_pooled is not None:
+            if clip_text_pooled.dim() > 2:
+                clip_text_pooled = clip_text_pooled.view(clip_text_pooled.shape[0], -1)
+            clip_text_pooled = clip_text_pooled.to(device=t_emb.device, dtype=t_emb.dtype)
+            clip_emb = self.clip_text_pooled_proj(clip_text_pooled)
+            adaln_input = self.time_text_embed(torch.cat([t_emb, clip_emb], dim=-1))
+        if clip_img_pooled is not None:
+            if clip_img_pooled.dim() > 2:
+                clip_img_pooled = clip_img_pooled.view(clip_img_pooled.shape[0], -1)
+            clip_img_pooled = clip_img_pooled.to(device=t_emb.device, dtype=t_emb.dtype)
+            adaln_input = adaln_input + self.clip_img_pooled_embedder(clip_img_pooled)
+        if isinstance(cap_feats, torch.Tensor):
+            try:
+                target_dtype = next(self.cap_embedder.parameters()).dtype
+            except StopIteration:
+                target_dtype = cap_feats.dtype
+            cap_feats = cap_feats.to(device=t_emb.device, dtype=target_dtype)
+        cap_feats = self.cap_embedder(cap_feats)
+        patches = transformer_options.get("patches", {})
+        x_is_tensor = True
+        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
+            x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options
+        )
+        freqs_cis = freqs_cis.to(img.device)
+        for i, layer in enumerate(self.layers):
+            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+            if "double_block" in patches:
+                for p in patches["double_block"]:
+                    out = p(
+                        {
+                            "img": img[:, cap_size[0] :],
+                            "txt": img[:, : cap_size[0]],
+                            "pe": freqs_cis[:, cap_size[0] :],
+                            "vec": adaln_input,
+                            "x": x,
+                            "block_index": i,
+                            "transformer_options": transformer_options,
+                        }
+                    )
+                    if isinstance(out, dict):
+                        if "img" in out:
+                            img[:, cap_size[0] :] = out["img"]
+                        if "txt" in out:
+                            img[:, : cap_size[0]] = out["txt"]
+
+        img = self.final_layer(img, adaln_input)
+        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)
+        img = img[:, :, :h, :w]
+        return img
+
+def NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP(**kwargs):
+    _pop_unexpected_kwargs(kwargs)
+    kwargs.setdefault("patch_size", 2)
+    kwargs.setdefault("in_channels", 16)
+    kwargs.setdefault("dim", 2304)
+    kwargs.setdefault("n_layers", 36)
+    kwargs.setdefault("n_heads", 24)
+    kwargs.setdefault("n_kv_heads", 8)
+    kwargs.setdefault("axes_dims", [32, 32, 32])
+    kwargs.setdefault("axes_lens", [1024, 512, 512])
+    return NewBieNextDiT_CLIP(**kwargs)
+
+def NewBieNextDiT(*, device=None, dtype=None, operations=None, **kwargs):
+    _pop_unexpected_kwargs(kwargs)
+    if operations is None:
+        operations = _fallback_operations()
+    if dtype is None:
+        dev_str = str(device) if device is not None else ""
+        if dev_str.startswith("cuda") and torch.cuda.is_available():
+            if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
+                dtype = torch.bfloat16
+            else:
+                dtype = torch.float16
+        else:
+            dtype = torch.float32
+    model = NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP(
+        device=device, dtype=dtype, operations=operations, **kwargs
+    )
+    return model
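Note: the pooled-conditioning path in _forward can be traced with toy modules. A shape sketch only, not the real model: the nn.Linear stand-ins replace the operations-wrapped layers, and the dims mirror the 3B defaults where min(dim, 1024) = 1024:

import torch
import torch.nn as nn

min_mod, clip_text_dim, clip_img_dim = 1024, 1024, 1024

t_emb = torch.randn(2, min_mod)              # t_embedder output
text_pooled = torch.randn(2, clip_text_dim)  # pooled CLIP text vector
img_pooled = torch.randn(2, clip_img_dim)    # pooled CLIP image vector

clip_text_pooled_proj = nn.Linear(clip_text_dim, clip_text_dim)  # stand-in
time_text_embed = nn.Sequential(nn.SiLU(), nn.Linear(min_mod + clip_text_dim, min_mod))
clip_img_pooled_embedder = nn.Linear(clip_img_dim, min_mod)      # stand-in

adaln_input = time_text_embed(torch.cat([t_emb, clip_text_pooled_proj(text_pooled)], dim=-1))
adaln_input = adaln_input + clip_img_pooled_embedder(img_pooled)
print(adaln_input.shape)  # torch.Size([2, 1024])

Since time_text_embed is zero-initialized in the class above, the text-pooled path contributes nothing until real weights are loaded.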
comfy/model_base.py
@@ -928,6 +928,90 @@ class Flux2(Flux):
             cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

+class NewBieImage(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        import comfy.ldm.newbie.model as nb
+        super().__init__(model_config, model_type, device=device, unet_model=nb.NewBieNextDiT)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out["c_crossattn"] = comfy.conds.CONDCrossAttn(cross_attn)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
+        cap_feats = kwargs.get("cap_feats", None)
+        if cap_feats is not None:
+            out["cap_feats"] = comfy.conds.CONDRegular(cap_feats)
+        cap_mask = kwargs.get("cap_mask", None)
+        if cap_mask is not None:
+            out["cap_mask"] = comfy.conds.CONDRegular(cap_mask)
+        clip_text_pooled = kwargs.get("clip_text_pooled", None)
+        if clip_text_pooled is not None:
+            out["clip_text_pooled"] = comfy.conds.CONDRegular(clip_text_pooled)
+        clip_img_pooled = kwargs.get("clip_img_pooled", None)
+        if clip_img_pooled is not None:
+            out["clip_img_pooled"] = comfy.conds.CONDRegular(clip_img_pooled)
+        return out
+
+    def extra_conds_shapes(self, **kwargs):
+        out = super().extra_conds_shapes(**kwargs)
+        cap_feats = kwargs.get("cap_feats", None)
+        if cap_feats is not None:
+            out["cap_feats"] = list(cap_feats.shape)
+        clip_text_pooled = kwargs.get("clip_text_pooled", None)
+        if clip_text_pooled is not None:
+            out["clip_text_pooled"] = list(clip_text_pooled.shape)
+        clip_img_pooled = kwargs.get("clip_img_pooled", None)
+        if clip_img_pooled is not None:
+            out["clip_img_pooled"] = list(clip_img_pooled.shape)
+        return out
+
+    def apply_model(
+        self, x, t,
+        c_concat=None, c_crossattn=None,
+        control=None, transformer_options={}, **kwargs
+    ):
+        sigma = t
+        try:
+            model_device = next(self.diffusion_model.parameters()).device
+        except StopIteration:
+            model_device = x.device
+        x_in = x.to(device=model_device)
+        sigma_in = sigma.to(device=model_device)
+        xc = self.model_sampling.calculate_input(sigma_in, x_in)
+        if c_concat is not None:
+            xc = torch.cat([xc] + [c_concat.to(device=model_device)], dim=1)
+        dtype = self.get_dtype()
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
+        xc = xc.to(dtype=dtype)
+        t_val = (1.0 - sigma_in).to(dtype=torch.float32)
+        cap_feats = kwargs.get("cap_feats", kwargs.get("cross_attn", c_crossattn))
+        cap_mask = kwargs.get("cap_mask", kwargs.get("attention_mask"))
+        clip_text_pooled = kwargs.get("clip_text_pooled")
+        clip_img_pooled = kwargs.get("clip_img_pooled")
+        if cap_feats is not None:
+            cap_feats = cap_feats.to(device=model_device, dtype=dtype)
+        if cap_mask is None and cap_feats is not None:
+            cap_mask = torch.ones(cap_feats.shape[:2], dtype=torch.bool, device=model_device)
+        elif cap_mask is not None:
+            cap_mask = cap_mask.to(device=model_device)
+            if cap_mask.dtype != torch.bool:
+                cap_mask = cap_mask != 0
+        model_kwargs = {}
+        if clip_text_pooled is not None:
+            model_kwargs["clip_text_pooled"] = clip_text_pooled.to(device=model_device, dtype=dtype)
+        if clip_img_pooled is not None:
+            model_kwargs["clip_img_pooled"] = clip_img_pooled.to(device=model_device, dtype=dtype)
+        model_output = self.diffusion_model(xc, t_val, cap_feats, cap_mask, **model_kwargs).float()
+        model_output = -model_output
+        denoised = self.model_sampling.calculate_denoised(sigma_in, model_output, x_in)
+        if denoised.device != x.device:
+            denoised = denoised.to(device=x.device)
+        return denoised
+
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

@@ -1110,10 +1194,6 @@ class Lumina2(BaseModel):
         if 'num_tokens' not in out:
             out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])

-        clip_text_pooled = kwargs["pooled_output"] # Newbie
-        if clip_text_pooled is not None:
-            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
-
         return out

 class WAN21(BaseModel):
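Note: apply_model feeds the transformer t = 1 - sigma and negates its raw output before calculate_denoised. A sketch of that arithmetic, assuming ComfyUI's flow-matching model sampling (CONST), where denoised = x - sigma * model_output; whether the underlying network predicts a data-ward or noise-ward velocity is not visible in this diff, the flip just reconciles the two sign conventions:

import torch

sigma = torch.tensor(0.8)     # sampler noise level in [0, 1]
x = torch.randn(1, 16, 8, 8)  # latent at this noise level
raw = torch.randn_like(x)     # hypothetical transformer output

t_val = (1.0 - sigma).to(torch.float32)  # time input, as in apply_model
model_output = -raw                      # sign flip, as in apply_model
denoised = x - model_output * sigma      # CONST-style calculate_denoised
assert torch.allclose(denoised, x + raw * sigma)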
comfy/model_detection.py
@@ -6,6 +6,26 @@
 import math
 import logging
 import torch

+def is_newbie_unet_state_dict(state_dict, key_prefix):
+    state_dict_keys = state_dict.keys()
+    try:
+        x_embed = state_dict[f"{key_prefix}x_embedder.weight"]
+        final = state_dict[f"{key_prefix}final_layer.linear.weight"]
+    except KeyError:
+        return False
+    if x_embed.ndim != 2:
+        return False
+    dim = x_embed.shape[0]
+    patch_dim = x_embed.shape[1]
+    if dim != 2304 or patch_dim != 64:
+        return False
+    if final.shape[0] != patch_dim or final.shape[1] != dim:
+        return False
+    n_layers = count_blocks(state_dict_keys, f"{key_prefix}layers." + "{}.")
+    if n_layers != 36:
+        return False
+    return True
+
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -411,7 +431,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
         return dit_config

-    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
+    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 / NewBie image
         dit_config = {}
         dit_config["image_model"] = "lumina2"
         dit_config["patch_size"] = 2
@@ -422,6 +442,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
         dit_config["qk_norm"] = True

+        if dit_config["dim"] == 2304 and is_newbie_unet_state_dict(state_dict, key_prefix): # NewBie image
+            dit_config["n_heads"] = 24
+            dit_config["n_kv_heads"] = 8
+            dit_config["axes_dims"] = [32, 32, 32]
+            dit_config["axes_lens"] = [1024, 512, 512]
+            dit_config["rope_theta"] = 10000.0
+            dit_config["model_type"] = "newbie_dit"
+            dit_config["image_model"] = "NewBieImage"
+            return dit_config
+
         if dit_config["dim"] == 2304: # Original Lumina 2
             dit_config["n_heads"] = 24
             dit_config["n_kv_heads"] = 8
@@ -429,9 +459,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["axes_lens"] = [300, 512, 512]
             dit_config["rope_theta"] = 10000.0
             dit_config["ffn_dim_multiplier"] = 4.0
-            ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
-            if ctd_weight is not None:
-                dit_config["clip_text_dim"] = ctd_weight.shape[0]
         elif dit_config["dim"] == 3840: # Z image
             dit_config["n_heads"] = 30
             dit_config["n_kv_heads"] = 30
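Note: the shape test encodes the patchify arithmetic. A DiT input projection has out_features = dim and in_features = patch_size**2 * in_channels, so 64 columns with patch size 2 implies 16 latent channels; dim 2304 with 36 layers then separates a NewBie checkpoint from stock Lumina 2, which shares dim 2304 but keeps axes_lens [300, 512, 512]. A toy check with hypothetical placeholder keys (not a real state dict), assuming this diff is applied:

import torch
from comfy.model_detection import is_newbie_unet_state_dict

patch_size, in_channels, dim, n_layers = 2, 16, 2304, 36
sd = {
    "x_embedder.weight": torch.zeros(dim, patch_size**2 * in_channels),  # [2304, 64]
    "final_layer.linear.weight": torch.zeros(patch_size**2 * in_channels, dim),
}
for i in range(n_layers):
    sd[f"layers.{i}.placeholder"] = torch.zeros(1)  # hypothetical per-layer key

print(is_newbie_unet_state_dict(sd, ""))  # True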
File diff suppressed because it is too large
comfy_extras/nodes_custom_sampler.py
@@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
     get_sampler = execute


+class SamplerSEEDS2(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SamplerSEEDS2",
+            category="sampling/custom_sampling/samplers",
+            inputs=[
+                io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
+                io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
+                io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
+                io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
+            ],
+            outputs=[io.Sampler.Output()]
+        )
+
+    @classmethod
+    def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
+        sampler_name = "seeds_2"
+        sampler = comfy.samplers.ksampler(
+            sampler_name,
+            {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
+        )
+        return io.NodeOutput(sampler)
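Note: outside the node graph, the same SAMPLER can be built directly. A minimal sketch of what execute() boils down to, using comfy.samplers.ksampler as the node itself does:

import comfy.samplers

sampler = comfy.samplers.ksampler(
    "seeds_2",
    {"eta": 1.0, "s_noise": 1.0, "r": 0.5, "solver_type": "phi_2"},
)
# `sampler` is the object the SamplerSEEDS2 node emits; it plugs into
# SamplerCustom / SamplerCustomAdvanced like any other SAMPLER output.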
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0

@@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
             SamplerDPMAdaptative,
             SamplerER_SDE,
             SamplerSASolver,
+            SamplerSEEDS2,
             SplitSigmas,
             SplitSigmasDenoise,
             FlipSigmas,