ComfyUI/comfy/ldm/newbie/model.py

from __future__ import annotations
from typing import Optional, Any, Dict
import torch
import torch.nn as nn
import comfy.ldm.common_dit as common_dit
from comfy.ldm.lumina.model import NextDiT as NextDiTBase
from .components import RMSNorm

#######################################################
# Adds support for the NewBie image model
#######################################################
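
# NewBieNextDiT_CLIP wraps Lumina's NextDiT (NextDiTBase) and adds two
# pooled-CLIP conditioning paths: a pooled text embedding fused with the
# timestep embedding through time_text_embed, and a pooled image embedding
# added to the AdaLN input through clip_img_pooled_embedder. Both tensors are
# looked up from the forward kwargs or from transformer_options
# (see _get_clip_from_kwargs).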


def _fallback_operations():
    try:
        import comfy.ops
        return comfy.ops.disable_weight_init
    except Exception:
        return None


def _pop_unexpected_kwargs(kwargs: Dict[str, Any]) -> None:
    # Drop config keys that callers may pass through but that the NextDiT
    # constructor does not accept.
    for k in (
        "model_type",
        "operation_settings",
        "unet_dtype",
        "weight_dtype",
        "precision",
        "extra_model_config",
    ):
        kwargs.pop(k, None)


class NewBieNextDiT_CLIP(NextDiTBase):
    def __init__(
        self,
        *args,
        clip_text_dim: int = 1024,
        clip_img_dim: int = 1024,
        device=None,
        dtype=None,
        operations=None,
        **kwargs,
    ):
        _pop_unexpected_kwargs(kwargs)
        if operations is None:
            operations = _fallback_operations()
        super().__init__(*args, device=device, dtype=dtype, operations=operations, **kwargs)
        self._nb_device = device
        self._nb_dtype = dtype
        self._nb_ops = operations

        # Width of the base model's timestep embedding: min(dim, 1024).
        min_mod = min(int(getattr(self, "dim", 1024)), 1024)

        if operations is not None and hasattr(operations, "Linear"):
            Linear = operations.Linear
            Norm = getattr(operations, "RMSNorm", None)
        else:
            Linear = nn.Linear
            Norm = None

        # Projection for the pooled CLIP text embedding.
        if Norm is not None:
            self.clip_text_pooled_proj = nn.Sequential(
                Norm(clip_text_dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype),
                Linear(clip_text_dim, clip_text_dim, bias=True, device=device, dtype=dtype),
            )
        else:
            self.clip_text_pooled_proj = nn.Sequential(
                RMSNorm(clip_text_dim),
                nn.Linear(clip_text_dim, clip_text_dim, bias=True),
            )
        nn.init.normal_(self.clip_text_pooled_proj[1].weight, std=0.01)
        nn.init.zeros_(self.clip_text_pooled_proj[1].bias)

        # Fuses the timestep embedding with the projected pooled text embedding.
        self.time_text_embed = nn.Sequential(
            nn.SiLU(),
            Linear(min_mod + clip_text_dim, min_mod, bias=True, device=device, dtype=dtype),
        )
        nn.init.zeros_(self.time_text_embed[1].weight)
        nn.init.zeros_(self.time_text_embed[1].bias)

        # Projection for the pooled CLIP image embedding, added to the AdaLN input.
        if Norm is not None:
            self.clip_img_pooled_embedder = nn.Sequential(
                Norm(clip_img_dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype),
                Linear(clip_img_dim, min_mod, bias=True, device=device, dtype=dtype),
            )
        else:
            self.clip_img_pooled_embedder = nn.Sequential(
                RMSNorm(clip_img_dim),
                nn.Linear(clip_img_dim, min_mod, bias=True),
            )
        nn.init.normal_(self.clip_img_pooled_embedder[1].weight, std=0.01)
        nn.init.zeros_(self.clip_img_pooled_embedder[1].bias)

    @staticmethod
    def _get_clip_from_kwargs(transformer_options: dict, kwargs: dict, key: str):
        # Pooled CLIP tensors may arrive as forward kwargs, directly in
        # transformer_options, or nested under transformer_options["extra_cond"].
        if key in kwargs:
            return kwargs.get(key)
        if transformer_options is not None and key in transformer_options:
            return transformer_options.get(key)
        extra = transformer_options.get("extra_cond", None) if transformer_options else None
        if isinstance(extra, dict) and key in extra:
            return extra.get(key)
        return None

    def _forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        num_tokens: int,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: dict = {},
        **kwargs,
    ):
        t = timesteps
        cap_feats = context
        cap_mask = attention_mask
        bs, c, h, w = x.shape
        x = common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        # Timestep embedding, optionally fused with the pooled CLIP embeddings.
        t_emb = self.t_embedder(t, dtype=x.dtype)
        adaln_input = t_emb

        clip_text_pooled = self._get_clip_from_kwargs(transformer_options, kwargs, "clip_text_pooled")
        clip_img_pooled = self._get_clip_from_kwargs(transformer_options, kwargs, "clip_img_pooled")
        if clip_text_pooled is not None:
            if clip_text_pooled.dim() > 2:
                clip_text_pooled = clip_text_pooled.view(clip_text_pooled.shape[0], -1)
            clip_text_pooled = clip_text_pooled.to(device=t_emb.device, dtype=t_emb.dtype)
            clip_emb = self.clip_text_pooled_proj(clip_text_pooled)
            adaln_input = self.time_text_embed(torch.cat([t_emb, clip_emb], dim=-1))
        if clip_img_pooled is not None:
            if clip_img_pooled.dim() > 2:
                clip_img_pooled = clip_img_pooled.view(clip_img_pooled.shape[0], -1)
            clip_img_pooled = clip_img_pooled.to(device=t_emb.device, dtype=t_emb.dtype)
            adaln_input = adaln_input + self.clip_img_pooled_embedder(clip_img_pooled)

        # Embed the caption features with the base model's caption embedder.
        if isinstance(cap_feats, torch.Tensor):
            try:
                target_dtype = next(self.cap_embedder.parameters()).dtype
            except StopIteration:
                target_dtype = cap_feats.dtype
            cap_feats = cap_feats.to(device=t_emb.device, dtype=target_dtype)
            cap_feats = self.cap_embedder(cap_feats)

        patches = transformer_options.get("patches", {})
        x_is_tensor = True
        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
            x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options
        )
        freqs_cis = freqs_cis.to(img.device)

        for i, layer in enumerate(self.layers):
            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
            # Optional per-block patches (e.g. from custom nodes) may rewrite
            # the image and caption token streams after each layer.
            if "double_block" in patches:
                for p in patches["double_block"]:
                    out = p(
                        {
                            "img": img[:, cap_size[0] :],
                            "txt": img[:, : cap_size[0]],
                            "pe": freqs_cis[:, cap_size[0] :],
                            "vec": adaln_input,
                            "x": x,
                            "block_index": i,
                            "transformer_options": transformer_options,
                        }
                    )
                    if isinstance(out, dict):
                        if "img" in out:
                            img[:, cap_size[0] :] = out["img"]
                        if "txt" in out:
                            img[:, : cap_size[0]] = out["txt"]

        img = self.final_layer(img, adaln_input)
        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)
        # Crop back to the original (pre-padding) spatial size.
        img = img[:, :, :h, :w]
        return img


def NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP(**kwargs):
    _pop_unexpected_kwargs(kwargs)
    kwargs.setdefault("patch_size", 2)
    kwargs.setdefault("in_channels", 16)
    kwargs.setdefault("dim", 2304)
    kwargs.setdefault("n_layers", 36)
    kwargs.setdefault("n_heads", 24)
    kwargs.setdefault("n_kv_heads", 8)
    kwargs.setdefault("axes_dims", [32, 32, 32])
    kwargs.setdefault("axes_lens", [1024, 512, 512])
    return NewBieNextDiT_CLIP(**kwargs)


def NewBieNextDiT(*, device=None, dtype=None, operations=None, **kwargs):
    _pop_unexpected_kwargs(kwargs)
    if operations is None:
        operations = _fallback_operations()
    if dtype is None:
        # Default dtype: bf16 (or fp16) on CUDA, fp32 everywhere else.
        dev_str = str(device) if device is not None else ""
        if dev_str.startswith("cuda") and torch.cuda.is_available():
            if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
                dtype = torch.bfloat16
            else:
                dtype = torch.float16
        else:
            dtype = torch.float32
    model = NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP(
        device=device, dtype=dtype, operations=operations, **kwargs
    )
    return model
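

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module; kept as comments because the
# full 3B configuration is expensive to instantiate). The device, dtype and
# zero tensors below are illustrative assumptions; the pooled-CLIP shapes
# match the default clip_text_dim / clip_img_dim of 1024.
#
#   model = NewBieNextDiT(device="cpu", dtype=torch.float32)
#   transformer_options = {
#       "clip_text_pooled": torch.zeros(1, 1024),
#       "clip_img_pooled": torch.zeros(1, 1024),
#   }
#   # _forward reads both pooled tensors via _get_clip_from_kwargs and folds
#   # them into the AdaLN conditioning alongside the timestep embedding.
# ---------------------------------------------------------------------------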