from __future__ import annotations

from typing import Optional, Any, Dict

import torch
import torch.nn as nn

import comfy.ldm.common_dit as common_dit
from comfy.ldm.lumina.model import NextDiT as NextDiTBase

from .components import RMSNorm


#######################################################
#      Adds support for the NewBie image model        #
#######################################################


def _fallback_operations():
    """Return ComfyUI's weight-init-free ops namespace, or None if it cannot be imported."""
    try:
        import comfy.ops
        return comfy.ops.disable_weight_init
    except Exception:
        return None


def _pop_unexpected_kwargs(kwargs: Dict[str, Any]) -> None:
    """Strip loader-side keys that the NextDiT constructor does not accept (in place)."""
    for k in (
        "model_type",
        "operation_settings",
        "unet_dtype",
        "weight_dtype",
        "precision",
        "extra_model_config",
    ):
        kwargs.pop(k, None)


class NewBieNextDiT_CLIP(NextDiTBase):
    """NextDiT variant that feeds pooled CLIP text/image embeddings into the AdaLN input."""

    def __init__(
        self,
        *args,
        clip_text_dim: int = 1024,
        clip_img_dim: int = 1024,
        device=None,
        dtype=None,
        operations=None,
        **kwargs,
    ):
        _pop_unexpected_kwargs(kwargs)
        if operations is None:
            operations = _fallback_operations()
        super().__init__(*args, device=device, dtype=dtype, operations=operations, **kwargs)

        self._nb_device = device
        self._nb_dtype = dtype
        self._nb_ops = operations

        # Width of the timestep embedding / AdaLN input (dim, capped at 1024).
        min_mod = min(int(getattr(self, "dim", 1024)), 1024)

        # Prefer ComfyUI's operations wrappers when available so the extra layers
        # follow the same weight handling as the rest of the model.
        if operations is not None and hasattr(operations, "Linear"):
            Linear = operations.Linear
            Norm = getattr(operations, "RMSNorm", None)
        else:
            Linear = nn.Linear
            Norm = None

        # Projection of the pooled CLIP text embedding.
        if Norm is not None:
            self.clip_text_pooled_proj = nn.Sequential(
                Norm(clip_text_dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype),
                Linear(clip_text_dim, clip_text_dim, bias=True, device=device, dtype=dtype),
            )
        else:
            self.clip_text_pooled_proj = nn.Sequential(
                RMSNorm(clip_text_dim),
                nn.Linear(clip_text_dim, clip_text_dim, bias=True),
            )
        nn.init.normal_(self.clip_text_pooled_proj[1].weight, std=0.01)
        nn.init.zeros_(self.clip_text_pooled_proj[1].bias)

        # Fuses the timestep embedding with the projected CLIP text embedding.
        self.time_text_embed = nn.Sequential(
            nn.SiLU(),
            Linear(min_mod + clip_text_dim, min_mod, bias=True, device=device, dtype=dtype),
        )
        nn.init.zeros_(self.time_text_embed[1].weight)
        nn.init.zeros_(self.time_text_embed[1].bias)

        # Embedder for the pooled CLIP image embedding, added onto the AdaLN input.
        if Norm is not None:
            self.clip_img_pooled_embedder = nn.Sequential(
                Norm(clip_img_dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype),
                Linear(clip_img_dim, min_mod, bias=True, device=device, dtype=dtype),
            )
        else:
            self.clip_img_pooled_embedder = nn.Sequential(
                RMSNorm(clip_img_dim),
                nn.Linear(clip_img_dim, min_mod, bias=True),
            )
        nn.init.normal_(self.clip_img_pooled_embedder[1].weight, std=0.01)
        nn.init.zeros_(self.clip_img_pooled_embedder[1].bias)

    @staticmethod
    def _get_clip_from_kwargs(transformer_options: dict, kwargs: dict, key: str):
        """Resolve a pooled CLIP tensor from kwargs, then transformer_options,
        then transformer_options["extra_cond"]."""
        if key in kwargs:
            return kwargs.get(key)
        if transformer_options is not None and key in transformer_options:
            return transformer_options.get(key)
        extra = transformer_options.get("extra_cond", None) if transformer_options else None
        if isinstance(extra, dict) and key in extra:
            return extra.get(key)
        return None

    def _forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        num_tokens: int,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: dict = {},
        **kwargs,
    ):
        t = timesteps
        cap_feats = context
        cap_mask = attention_mask

        bs, c, h, w = x.shape
        x = common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        t_emb = self.t_embedder(t, dtype=x.dtype)
        adaln_input = t_emb

        # Pooled CLIP conditioning may arrive via kwargs or transformer_options.
        clip_text_pooled = self._get_clip_from_kwargs(transformer_options, kwargs, "clip_text_pooled")
        clip_img_pooled = self._get_clip_from_kwargs(transformer_options, kwargs, "clip_img_pooled")

        if clip_text_pooled is not None:
            if clip_text_pooled.dim() > 2:
                clip_text_pooled = clip_text_pooled.view(clip_text_pooled.shape[0], -1)
            clip_text_pooled = clip_text_pooled.to(device=t_emb.device, dtype=t_emb.dtype)
            clip_emb = self.clip_text_pooled_proj(clip_text_pooled)
            adaln_input = self.time_text_embed(torch.cat([t_emb, clip_emb], dim=-1))

        if clip_img_pooled is not None:
            if clip_img_pooled.dim() > 2:
                clip_img_pooled = clip_img_pooled.view(clip_img_pooled.shape[0], -1)
            clip_img_pooled = clip_img_pooled.to(device=t_emb.device, dtype=t_emb.dtype)
            adaln_input = adaln_input + self.clip_img_pooled_embedder(clip_img_pooled)

        if isinstance(cap_feats, torch.Tensor):
            try:
                target_dtype = next(self.cap_embedder.parameters()).dtype
            except StopIteration:
                target_dtype = cap_feats.dtype
            cap_feats = cap_feats.to(device=t_emb.device, dtype=target_dtype)
        cap_feats = self.cap_embedder(cap_feats)

        patches = transformer_options.get("patches", {})
        x_is_tensor = True

        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
            x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options
        )
        freqs_cis = freqs_cis.to(img.device)

        for i, layer in enumerate(self.layers):
            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
            # Allow "double_block" patches (e.g. from custom nodes) to edit the image/text tokens.
            if "double_block" in patches:
                for p in patches["double_block"]:
                    out = p(
                        {
                            "img": img[:, cap_size[0] :],
                            "txt": img[:, : cap_size[0]],
                            "pe": freqs_cis[:, cap_size[0] :],
                            "vec": adaln_input,
                            "x": x,
                            "block_index": i,
                            "transformer_options": transformer_options,
                        }
                    )
                    if isinstance(out, dict):
                        if "img" in out:
                            img[:, cap_size[0] :] = out["img"]
                        if "txt" in out:
                            img[:, : cap_size[0]] = out["txt"]

        img = self.final_layer(img, adaln_input)
        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)
        img = img[:, :, :h, :w]
        return img


def NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP(**kwargs):
    """Factory for the 3B NewBie configuration (patch size 2, GQA, AdaLN refiner)."""
    _pop_unexpected_kwargs(kwargs)
    kwargs.setdefault("patch_size", 2)
    kwargs.setdefault("in_channels", 16)
    kwargs.setdefault("dim", 2304)
    kwargs.setdefault("n_layers", 36)
    kwargs.setdefault("n_heads", 24)
    kwargs.setdefault("n_kv_heads", 8)
    kwargs.setdefault("axes_dims", [32, 32, 32])
    kwargs.setdefault("axes_lens", [1024, 512, 512])
    return NewBieNextDiT_CLIP(**kwargs)


def NewBieNextDiT(*, device=None, dtype=None, operations=None, **kwargs):
    """Build the NewBie NextDiT model, choosing a sensible dtype when none is given."""
    _pop_unexpected_kwargs(kwargs)
    if operations is None:
        operations = _fallback_operations()
    if dtype is None:
        dev_str = str(device) if device is not None else ""
        if dev_str.startswith("cuda") and torch.cuda.is_available():
            # Prefer bfloat16 on GPUs that support it, otherwise fall back to float16.
            if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
                dtype = torch.bfloat16
            else:
                dtype = torch.float16
        else:
            dtype = torch.float32
    model = NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP(
        device=device, dtype=dtype, operations=operations, **kwargs
    )
    return model
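

# ---------------------------------------------------------------------------
# Minimal usage sketch: (a) the lookup order used by _get_clip_from_kwargs for
# pooled CLIP conditioning, and (b) a hypothetical tiny-model smoke test on CPU.
# The size overrides (dim=96, n_layers=2, ...), cap_feat_dim, the num_tokens
# semantics and all tensor shapes below are assumptions; real checkpoints use
# the 3B defaults from NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Pooled CLIP tensors are resolved from kwargs first, then transformer_options,
    # then transformer_options["extra_cond"].
    opts = {
        "extra_cond": {
            "clip_text_pooled": torch.zeros(1, 1024),  # assumed pooled text shape
            "clip_img_pooled": torch.zeros(1, 1024),   # assumed pooled image shape
        }
    }
    pooled = NewBieNextDiT_CLIP._get_clip_from_kwargs(opts, {}, "clip_text_pooled")
    print("pooled CLIP text:", None if pooled is None else tuple(pooled.shape))

    # Hypothetical tiny configuration; assumes the NextDiT base accepts
    # cap_feat_dim and that sum(axes_dims) == dim // n_heads.
    model = NewBieNextDiT(
        device="cpu",
        dtype=torch.float32,
        dim=96,
        n_layers=2,
        n_heads=4,
        n_kv_heads=2,
        axes_dims=[8, 8, 8],
        cap_feat_dim=64,
    )
    x = torch.randn(1, 16, 32, 32)       # latent image (assumed shape)
    timesteps = torch.full((1,), 0.5)    # timestep in [0, 1] (assumed range)
    context = torch.randn(1, 32, 64)     # caption features; last dim matches cap_feat_dim
    out = model._forward(
        x, timesteps, context,
        num_tokens=context.shape[1],     # assumed: number of caption tokens
        transformer_options=opts,
    )
    print("output:", tuple(out.shape))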