mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Merge 49fef1697c into 55ebd287ee
This commit is contained in:
commit
a439638cdf
763
comfy/ldm/twinflow/model.py
Normal file
763
comfy/ldm/twinflow/model.py
Normal file
@ -0,0 +1,763 @@
|
|||||||
|
"""
|
||||||
|
TwinFlow-Z-Image custom model architecture for ComfyUI.
|
||||||
|
Based on the Lumina-Image 2.0 / Z-Image architecture.
|
||||||
|
Supports the unique dual timestep embedding architecture of TwinFlow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||||
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
|
from comfy.ldm.flux.math import apply_rope
|
||||||
|
import comfy.patcher_extension
|
||||||
|
|
||||||
|
|
||||||
|
def clamp_fp16(x):
|
||||||
|
if x.dtype == torch.float16:
|
||||||
|
return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, scale):
|
||||||
|
return x * (1 + scale.unsqueeze(1))
|
||||||
|
|
||||||
|
|
||||||
|
class JointAttention(nn.Module):
|
||||||
|
"""Multi-head attention module with combined QKV weights."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: Optional[int],
|
||||||
|
qk_norm: bool,
|
||||||
|
out_bias: bool = False,
|
||||||
|
operation_settings={},
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
|
||||||
|
self.n_local_heads = n_heads
|
||||||
|
self.n_local_kv_heads = self.n_kv_heads
|
||||||
|
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
|
self.head_dim = dim // n_heads
|
||||||
|
|
||||||
|
self.qkv = operation_settings.get("operations").Linear(
|
||||||
|
dim,
|
||||||
|
(n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
self.out = operation_settings.get("operations").Linear(
|
||||||
|
n_heads * self.head_dim,
|
||||||
|
dim,
|
||||||
|
bias=out_bias,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if qk_norm:
|
||||||
|
self.q_norm = operation_settings.get("operations").RMSNorm(
|
||||||
|
self.head_dim,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
self.k_norm = operation_settings.get("operations").RMSNorm(
|
||||||
|
self.head_dim,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.q_norm = self.k_norm = nn.Identity()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_mask: torch.Tensor,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
transformer_options={},
|
||||||
|
) -> torch.Tensor:
|
||||||
|
bsz, seqlen, _ = x.shape
|
||||||
|
|
||||||
|
xq, xk, xv = torch.split(
|
||||||
|
self.qkv(x),
|
||||||
|
[
|
||||||
|
self.n_local_heads * self.head_dim,
|
||||||
|
self.n_local_kv_heads * self.head_dim,
|
||||||
|
self.n_local_kv_heads * self.head_dim,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||||
|
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||||
|
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||||
|
|
||||||
|
xq = self.q_norm(xq)
|
||||||
|
xk = self.k_norm(xk)
|
||||||
|
|
||||||
|
xq, xk = apply_rope(xq, xk, freqs_cis)
|
||||||
|
|
||||||
|
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
|
if n_rep > 1:
|
||||||
|
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
|
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||||
|
|
||||||
|
output = optimized_attention_masked(
|
||||||
|
xq.movedim(1, 2),
|
||||||
|
xk.movedim(1, 2),
|
||||||
|
xv.movedim(1, 2),
|
||||||
|
self.n_local_heads,
|
||||||
|
x_mask,
|
||||||
|
skip_reshape=True,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.out(output)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
"""Feed-forward module with SiLU gating."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
multiple_of: int,
|
||||||
|
ffn_dim_multiplier: Optional[float],
|
||||||
|
operation_settings={},
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if ffn_dim_multiplier is not None:
|
||||||
|
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||||
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
|
self.w1 = operation_settings.get("operations").Linear(
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
bias=False,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
self.w2 = operation_settings.get("operations").Linear(
|
||||||
|
hidden_dim,
|
||||||
|
dim,
|
||||||
|
bias=False,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
self.w3 = operation_settings.get("operations").Linear(
|
||||||
|
dim,
|
||||||
|
hidden_dim,
|
||||||
|
bias=False,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward_silu_gating(self, x1, x3):
|
||||||
|
return clamp_fp16(F.silu(x1) * x3)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class TwinFlowTransformerBlock(nn.Module):
|
||||||
|
"""Transformer block with adaLN modulation for TwinFlow."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
multiple_of: int,
|
||||||
|
ffn_dim_multiplier: float,
|
||||||
|
norm_eps: float,
|
||||||
|
qk_norm: bool,
|
||||||
|
modulation=False,
|
||||||
|
z_image_modulation=False,
|
||||||
|
attn_out_bias=False,
|
||||||
|
operation_settings={},
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.head_dim = dim // n_heads
|
||||||
|
self.attention = JointAttention(
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
qk_norm,
|
||||||
|
out_bias=attn_out_bias,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
)
|
||||||
|
self.feed_forward = FeedForward(
|
||||||
|
dim=dim,
|
||||||
|
hidden_dim=dim,
|
||||||
|
multiple_of=multiple_of,
|
||||||
|
ffn_dim_multiplier=ffn_dim_multiplier,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
)
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.attention_norm1 = operation_settings.get("operations").RMSNorm(
|
||||||
|
dim,
|
||||||
|
eps=norm_eps,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
self.ffn_norm1 = operation_settings.get("operations").RMSNorm(
|
||||||
|
dim,
|
||||||
|
eps=norm_eps,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
self.attention_norm2 = operation_settings.get("operations").RMSNorm(
|
||||||
|
dim,
|
||||||
|
eps=norm_eps,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
self.ffn_norm2 = operation_settings.get("operations").RMSNorm(
|
||||||
|
dim,
|
||||||
|
eps=norm_eps,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if z_image_modulation:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
operation_settings.get("operations").Linear(
|
||||||
|
min(dim, 256),
|
||||||
|
4 * dim,
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_mask: torch.Tensor,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
adaln_input: Optional[torch.Tensor] = None,
|
||||||
|
transformer_options={},
|
||||||
|
):
|
||||||
|
if adaln_input is not None:
|
||||||
|
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||||
|
|
||||||
|
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||||
|
clamp_fp16(
|
||||||
|
self.attention(
|
||||||
|
modulate(self.attention_norm1(x), scale_msa),
|
||||||
|
x_mask,
|
||||||
|
freqs_cis,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||||
|
clamp_fp16(
|
||||||
|
self.feed_forward(
|
||||||
|
modulate(self.ffn_norm1(x), scale_mlp),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = x + self.attention_norm2(
|
||||||
|
clamp_fp16(
|
||||||
|
self.attention(
|
||||||
|
self.attention_norm1(x),
|
||||||
|
x_mask,
|
||||||
|
freqs_cis,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
|
||||||
|
"""Final layer with LayerNorm and output projection."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
patch_size: int,
|
||||||
|
out_channels: int,
|
||||||
|
z_image_modulation=False,
|
||||||
|
operation_settings={},
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operation_settings.get("operations").LayerNorm(
|
||||||
|
hidden_size,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
self.linear = operation_settings.get("operations").Linear(
|
||||||
|
hidden_size,
|
||||||
|
patch_size * patch_size * out_channels,
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
|
||||||
|
min_mod = 256 if z_image_modulation else 1024
|
||||||
|
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operation_settings.get("operations").Linear(
|
||||||
|
min(hidden_size, min_mod),
|
||||||
|
hidden_size,
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, c: torch.Tensor):
|
||||||
|
scale = self.adaLN_modulation(c)
|
||||||
|
x = modulate(self.norm_final(x), scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TwinFlowZImageTransformer(nn.Module):
|
||||||
|
"""
|
||||||
|
TwinFlow-Z-Image transformer model.
|
||||||
|
|
||||||
|
This custom architecture handles dual timestep embeddings
|
||||||
|
(t_embedder and t_embedder_2), the primary TwinFlow distinction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size: int = 2,
|
||||||
|
in_channels: int = 16,
|
||||||
|
dim: int = 3840,
|
||||||
|
n_layers: int = 30,
|
||||||
|
n_refiner_layers: int = 2,
|
||||||
|
n_heads: int = 30,
|
||||||
|
n_kv_heads: Optional[int] = None,
|
||||||
|
multiple_of: int = 256,
|
||||||
|
ffn_dim_multiplier: float = 2.6666666666666665,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
cap_feat_dim: int = 2560,
|
||||||
|
axes_dims: List[int] = (32, 48, 48),
|
||||||
|
axes_lens: List[int] = (1, 1536, 512, 512),
|
||||||
|
rope_theta: float = 256.0,
|
||||||
|
z_image_modulation: bool = True,
|
||||||
|
time_scale: float = 1000.0,
|
||||||
|
pad_tokens_multiple=None,
|
||||||
|
clip_text_dim=None,
|
||||||
|
image_model=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
operation_settings = {
|
||||||
|
"operations": operations,
|
||||||
|
"device": device,
|
||||||
|
"dtype": dtype,
|
||||||
|
}
|
||||||
|
self.time_embed_dim = 256 if z_image_modulation else min(dim, 1024)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.time_scale = time_scale
|
||||||
|
self.pad_tokens_multiple = pad_tokens_multiple
|
||||||
|
|
||||||
|
self.x_embedder = operation_settings.get("operations").Linear(
|
||||||
|
in_features=patch_size * patch_size * in_channels,
|
||||||
|
out_features=dim,
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.t_embedder = TimestepEmbedder(
|
||||||
|
min(dim, 1024),
|
||||||
|
output_size=self.time_embed_dim if z_image_modulation else None,
|
||||||
|
**operation_settings,
|
||||||
|
)
|
||||||
|
self.t_embedder_2 = TimestepEmbedder(
|
||||||
|
min(dim, 1024),
|
||||||
|
output_size=self.time_embed_dim if z_image_modulation else None,
|
||||||
|
**operation_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.noise_refiner = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TwinFlowTransformerBlock(
|
||||||
|
layer_id,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
multiple_of,
|
||||||
|
ffn_dim_multiplier,
|
||||||
|
norm_eps,
|
||||||
|
qk_norm,
|
||||||
|
modulation=True,
|
||||||
|
z_image_modulation=z_image_modulation,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
)
|
||||||
|
for layer_id in range(n_refiner_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.context_refiner = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TwinFlowTransformerBlock(
|
||||||
|
layer_id,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
multiple_of,
|
||||||
|
ffn_dim_multiplier,
|
||||||
|
norm_eps,
|
||||||
|
qk_norm,
|
||||||
|
modulation=False,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
)
|
||||||
|
for layer_id in range(n_refiner_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cap_embedder = nn.Sequential(
|
||||||
|
operation_settings.get("operations").RMSNorm(
|
||||||
|
cap_feat_dim,
|
||||||
|
eps=norm_eps,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
operation_settings.get("operations").Linear(
|
||||||
|
cap_feat_dim,
|
||||||
|
dim,
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.clip_text_pooled_proj = None
|
||||||
|
if clip_text_dim is not None:
|
||||||
|
self.clip_text_dim = clip_text_dim
|
||||||
|
self.clip_text_pooled_proj = nn.Sequential(
|
||||||
|
operation_settings.get("operations").RMSNorm(
|
||||||
|
clip_text_dim,
|
||||||
|
eps=norm_eps,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
operation_settings.get("operations").Linear(
|
||||||
|
clip_text_dim,
|
||||||
|
clip_text_dim,
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.clip_text_concat_proj = nn.Sequential(
|
||||||
|
operation_settings.get("operations").RMSNorm(
|
||||||
|
clip_text_dim + self.time_embed_dim,
|
||||||
|
eps=norm_eps,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
operation_settings.get("operations").Linear(
|
||||||
|
clip_text_dim + self.time_embed_dim,
|
||||||
|
self.time_embed_dim,
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TwinFlowTransformerBlock(
|
||||||
|
layer_id,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
multiple_of,
|
||||||
|
ffn_dim_multiplier,
|
||||||
|
norm_eps,
|
||||||
|
qk_norm,
|
||||||
|
z_image_modulation=z_image_modulation,
|
||||||
|
attn_out_bias=False,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
)
|
||||||
|
for layer_id in range(n_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(
|
||||||
|
dim,
|
||||||
|
patch_size,
|
||||||
|
self.out_channels,
|
||||||
|
z_image_modulation=z_image_modulation,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.pad_tokens_multiple is not None:
|
||||||
|
self.x_pad_token = nn.Parameter(torch.zeros((1, dim), device=device, dtype=dtype))
|
||||||
|
self.cap_pad_token = nn.Parameter(torch.zeros((1, dim), device=device, dtype=dtype))
|
||||||
|
|
||||||
|
assert (dim // n_heads) == sum(axes_dims)
|
||||||
|
self.axes_dims = axes_dims
|
||||||
|
self.axes_lens = axes_lens
|
||||||
|
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
|
||||||
|
self.dim = dim
|
||||||
|
self.n_heads = n_heads
|
||||||
|
|
||||||
|
def _compute_twinflow_adaln(self, t: torch.Tensor, x_dtype: torch.dtype, transformer_options={}):
|
||||||
|
"""
|
||||||
|
Compute TwinFlow adaLN input.
|
||||||
|
|
||||||
|
If `target_timestep` is provided in transformer options, apply the
|
||||||
|
TwinFlow delta-time conditioning:
|
||||||
|
t_emb + t_embedder_2((target - t) * time_scale) * abs(target - t)
|
||||||
|
otherwise fallback to the baseline additive embedding.
|
||||||
|
"""
|
||||||
|
t_emb = self.t_embedder(t * self.time_scale, dtype=x_dtype)
|
||||||
|
target_timestep = transformer_options.get("target_timestep", None)
|
||||||
|
if target_timestep is None:
|
||||||
|
t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x_dtype)
|
||||||
|
return t_emb + t_emb_2
|
||||||
|
|
||||||
|
target_t = torch.as_tensor(target_timestep, device=t.device, dtype=t.dtype)
|
||||||
|
if target_t.ndim == 0:
|
||||||
|
target_t = target_t.expand_as(t)
|
||||||
|
|
||||||
|
# If values look scaled (roughly sigma/timestep in [0..1000]), normalize.
|
||||||
|
t_abs_max = float(t.detach().abs().max().item()) if t.numel() else 0.0
|
||||||
|
tt_abs_max = float(target_t.detach().abs().max().item()) if target_t.numel() else 0.0
|
||||||
|
scaled_domain = (max(t_abs_max, tt_abs_max) > 2.0) and (self.time_scale > 2.0)
|
||||||
|
if scaled_domain:
|
||||||
|
t_norm = t / self.time_scale
|
||||||
|
tt_norm = target_t / self.time_scale
|
||||||
|
else:
|
||||||
|
t_norm = t
|
||||||
|
tt_norm = target_t
|
||||||
|
|
||||||
|
delta_abs = (t_norm - tt_norm).abs().unsqueeze(1).to(t_emb.dtype)
|
||||||
|
diff_in = (tt_norm - t_norm) * self.time_scale
|
||||||
|
t_emb_2 = self.t_embedder_2(diff_in, dtype=x_dtype)
|
||||||
|
return t_emb + t_emb_2 * delta_abs
|
||||||
|
|
||||||
|
def unpatchify(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
img_size: List[Tuple[int, int]],
|
||||||
|
cap_size: List[int],
|
||||||
|
return_tensor=False,
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
pH = pW = self.patch_size
|
||||||
|
imgs = []
|
||||||
|
for i in range(x.size(0)):
|
||||||
|
H, W = img_size[i]
|
||||||
|
begin = cap_size[i]
|
||||||
|
end = begin + (H // pH) * (W // pW)
|
||||||
|
imgs.append(
|
||||||
|
x[i][begin:end]
|
||||||
|
.view(H // pH, W // pW, pH, pW, self.out_channels)
|
||||||
|
.permute(4, 0, 2, 1, 3)
|
||||||
|
.flatten(3, 4)
|
||||||
|
.flatten(1, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
if return_tensor:
|
||||||
|
imgs = torch.stack(imgs, dim=0)
|
||||||
|
return imgs
|
||||||
|
|
||||||
|
def patchify_and_embed(
|
||||||
|
self,
|
||||||
|
x: List[torch.Tensor] | torch.Tensor,
|
||||||
|
cap_feats: torch.Tensor,
|
||||||
|
cap_mask: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
num_tokens,
|
||||||
|
transformer_options={},
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||||
|
bsz = len(x)
|
||||||
|
pH = pW = self.patch_size
|
||||||
|
device = x[0].device
|
||||||
|
|
||||||
|
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||||
|
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||||
|
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
x = self.x_embedder(
|
||||||
|
x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
rope_options = transformer_options.get("rope_options", {})
|
||||||
|
h_scale = rope_options.get("scale_y", 1.0)
|
||||||
|
w_scale = rope_options.get("scale_x", 1.0)
|
||||||
|
h_start = rope_options.get("shift_y", 0.0)
|
||||||
|
w_start = rope_options.get("shift_x", 0.0)
|
||||||
|
|
||||||
|
H_tokens, W_tokens = H // pH, W // pW
|
||||||
|
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||||
|
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||||
|
x_pos_ids[:, :, 1] = (
|
||||||
|
torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start
|
||||||
|
).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||||
|
x_pos_ids[:, :, 2] = (
|
||||||
|
torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start
|
||||||
|
).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||||
|
|
||||||
|
x_pad_extra = 0
|
||||||
|
if self.pad_tokens_multiple is not None:
|
||||||
|
x_pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||||
|
x = torch.cat(
|
||||||
|
(
|
||||||
|
x,
|
||||||
|
self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], x_pad_extra, 1),
|
||||||
|
),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, x_pad_extra))
|
||||||
|
|
||||||
|
cap_pad_extra = 0
|
||||||
|
if self.pad_tokens_multiple is not None:
|
||||||
|
cap_pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||||
|
cap_feats = torch.cat(
|
||||||
|
(
|
||||||
|
cap_feats,
|
||||||
|
self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(cap_feats.shape[0], cap_pad_extra, 1),
|
||||||
|
),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, cap_pad_extra), value=0)
|
||||||
|
if cap_mask is not None and cap_pad_extra > 0:
|
||||||
|
cap_mask = torch.nn.functional.pad(cap_mask, (0, cap_pad_extra), value=0)
|
||||||
|
|
||||||
|
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||||
|
|
||||||
|
for layer in self.context_refiner:
|
||||||
|
cap_feats = layer(
|
||||||
|
cap_feats,
|
||||||
|
cap_mask,
|
||||||
|
freqs_cis[:, : cap_pos_ids.shape[1]],
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
padded_img_mask = None
|
||||||
|
for _, layer in enumerate(self.noise_refiner):
|
||||||
|
x = layer(
|
||||||
|
x,
|
||||||
|
padded_img_mask,
|
||||||
|
freqs_cis[:, cap_pos_ids.shape[1] :],
|
||||||
|
t,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||||
|
if cap_mask is not None:
|
||||||
|
cap_mask_bool = cap_mask if cap_mask.dtype == torch.bool else cap_mask > 0
|
||||||
|
img_mask = torch.ones((bsz, x.shape[1]), device=cap_mask.device, dtype=torch.bool)
|
||||||
|
if x_pad_extra > 0:
|
||||||
|
img_mask[:, -x_pad_extra:] = False
|
||||||
|
mask = torch.cat((cap_mask_bool, img_mask), dim=1)
|
||||||
|
else:
|
||||||
|
mask = None
|
||||||
|
img_sizes = [(H, W)] * bsz
|
||||||
|
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||||
|
|
||||||
|
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(
|
||||||
|
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL,
|
||||||
|
kwargs.get("transformer_options", {}),
|
||||||
|
),
|
||||||
|
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timesteps,
|
||||||
|
context,
|
||||||
|
num_tokens,
|
||||||
|
attention_mask=None,
|
||||||
|
transformer_options=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if transformer_options is None:
|
||||||
|
transformer_options = {}
|
||||||
|
|
||||||
|
t = 1.0 - timesteps
|
||||||
|
|
||||||
|
adaln_input = self._compute_twinflow_adaln(t, x.dtype, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
cap_feats = context
|
||||||
|
cap_mask = attention_mask
|
||||||
|
bs, c, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||||
|
|
||||||
|
cap_feats = self.cap_embedder(cap_feats)
|
||||||
|
|
||||||
|
if self.clip_text_pooled_proj is not None:
|
||||||
|
pooled = kwargs.get("clip_text_pooled", None)
|
||||||
|
if pooled is not None:
|
||||||
|
pooled = self.clip_text_pooled_proj(pooled)
|
||||||
|
else:
|
||||||
|
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||||
|
adaln_input = torch.cat((adaln_input, pooled), dim=-1)
|
||||||
|
adaln_input = self.clip_text_concat_proj(adaln_input)
|
||||||
|
|
||||||
|
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
|
||||||
|
x,
|
||||||
|
cap_feats,
|
||||||
|
cap_mask,
|
||||||
|
adaln_input,
|
||||||
|
num_tokens,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
freqs_cis = freqs_cis.to(img.device)
|
||||||
|
|
||||||
|
transformer_options["total_blocks"] = len(self.layers)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
|
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
|
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
img = self.final_layer(img, adaln_input)
|
||||||
|
img = self.unpatchify(
|
||||||
|
img,
|
||||||
|
img_size,
|
||||||
|
cap_size,
|
||||||
|
return_tensor=isinstance(x, torch.Tensor),
|
||||||
|
)[:, :, :h, :w]
|
||||||
|
|
||||||
|
return -img
|
||||||
@ -328,6 +328,15 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||||
key_map[key_lora] = to
|
key_map[key_lora] = to
|
||||||
|
|
||||||
|
# TwinFlow-Z-Image LoRAs can target t_embedder_2.* keys.
|
||||||
|
# Alias them back to t_embedder.* targets for compatibility.
|
||||||
|
if isinstance(model, comfy.model_base.TwinFlow_Z_Image):
|
||||||
|
for key in list(key_map.keys()):
|
||||||
|
if "t_embedder." in key and "t_embedder_2." not in key:
|
||||||
|
key_2 = key.replace("t_embedder.", "t_embedder_2.", 1)
|
||||||
|
if key_2 not in key_map:
|
||||||
|
key_map[key_2] = key_map[key]
|
||||||
|
|
||||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||||
|
|||||||
@ -32,7 +32,15 @@ def convert_uso_lora(sd):
|
|||||||
sd_out[k_to] = tensor
|
sd_out[k_to] = tensor
|
||||||
return sd_out
|
return sd_out
|
||||||
|
|
||||||
|
def twinflow_z_image_lora_to_diffusers(state_dict):
|
||||||
|
"""Convert TwinFlow LoRA state dict for diffusers compatibility."""
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
if "t_embedder_2" not in key and key.startswith("t_embedder."):
|
||||||
|
new_key = key.replace("t_embedder.", "t_embedder_2.", 1)
|
||||||
|
if new_key not in state_dict:
|
||||||
|
state_dict[new_key] = state_dict.pop(key)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
def convert_lora(sd):
|
def convert_lora(sd):
|
||||||
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
||||||
return convert_lora_bfl_control(sd)
|
return convert_lora_bfl_control(sd)
|
||||||
@ -40,4 +48,6 @@ def convert_lora(sd):
|
|||||||
return convert_lora_wan_fun(sd)
|
return convert_lora_wan_fun(sd)
|
||||||
if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
|
if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
|
||||||
return convert_uso_lora(sd)
|
return convert_uso_lora(sd)
|
||||||
|
if any(k.startswith("t_embedder.") for k in sd.keys()):
|
||||||
|
return twinflow_z_image_lora_to_diffusers(sd)
|
||||||
return sd
|
return sd
|
||||||
|
|||||||
@ -40,6 +40,7 @@ import comfy.ldm.hunyuan_video.model
|
|||||||
import comfy.ldm.cosmos.model
|
import comfy.ldm.cosmos.model
|
||||||
import comfy.ldm.cosmos.predict2
|
import comfy.ldm.cosmos.predict2
|
||||||
import comfy.ldm.lumina.model
|
import comfy.ldm.lumina.model
|
||||||
|
import comfy.ldm.twinflow.model
|
||||||
import comfy.ldm.wan.model
|
import comfy.ldm.wan.model
|
||||||
import comfy.ldm.wan.model_animate
|
import comfy.ldm.wan.model_animate
|
||||||
import comfy.ldm.hunyuan3d.model
|
import comfy.ldm.hunyuan3d.model
|
||||||
@ -1281,6 +1282,11 @@ class ZImagePixelSpace(Lumina2):
|
|||||||
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||||
self.memory_usage_factor_conds = ("ref_latents",)
|
self.memory_usage_factor_conds = ("ref_latents",)
|
||||||
|
|
||||||
|
class TwinFlow_Z_Image(Lumina2):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.twinflow.model.TwinFlowZImageTransformer)
|
||||||
|
self.memory_usage_factor_conds = ("ref_latents",)
|
||||||
|
|
||||||
class WAN21(BaseModel):
|
class WAN21(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
|
|||||||
@ -44,6 +44,48 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
|||||||
def detect_unet_config(state_dict, key_prefix, metadata=None):
|
def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||||
state_dict_keys = list(state_dict.keys())
|
state_dict_keys = list(state_dict.keys())
|
||||||
|
|
||||||
|
# TwinFlow-Z-Image: detect dual timestep embedder checkpoints first.
|
||||||
|
if any(k.startswith('{}t_embedder_2.'.format(key_prefix)) for k in state_dict_keys):
|
||||||
|
dit_config = {
|
||||||
|
"image_model": "twinflow_z_image",
|
||||||
|
"architecture": "TwinFlow_Z_Image",
|
||||||
|
"patch_size": 2,
|
||||||
|
"in_channels": 16,
|
||||||
|
"qk_norm": True,
|
||||||
|
"ffn_dim_multiplier": (8.0 / 3.0),
|
||||||
|
"z_image_modulation": True,
|
||||||
|
"time_scale": 1000.0,
|
||||||
|
"n_refiner_layers": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
cap_embedder_key = '{}cap_embedder.1.weight'.format(key_prefix)
|
||||||
|
if cap_embedder_key in state_dict:
|
||||||
|
w = state_dict[cap_embedder_key]
|
||||||
|
dit_config["dim"] = w.shape[0]
|
||||||
|
dit_config["cap_feat_dim"] = w.shape[1]
|
||||||
|
|
||||||
|
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
|
||||||
|
|
||||||
|
# Match Z-Image style defaults (TwinFlow checkpoints are 3840-dim variants).
|
||||||
|
dit_config["n_heads"] = 30
|
||||||
|
dit_config["n_kv_heads"] = 30
|
||||||
|
dit_config["axes_dims"] = [32, 48, 48]
|
||||||
|
dit_config["axes_lens"] = [1536, 512, 512]
|
||||||
|
dit_config["rope_theta"] = 256.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
dit_config["allow_fp16"] = torch.std(
|
||||||
|
state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)],
|
||||||
|
unbiased=False
|
||||||
|
).item() < 0.42
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys or '{}x_pad_token'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["pad_tokens_multiple"] = 32
|
||||||
|
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
|
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
|
||||||
unet_config = {}
|
unet_config = {}
|
||||||
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
||||||
|
|||||||
14
comfy/sd.py
14
comfy/sd.py
@ -74,6 +74,20 @@ import comfy.latent_formats
|
|||||||
|
|
||||||
import comfy.ldm.flux.redux
|
import comfy.ldm.flux.redux
|
||||||
|
|
||||||
|
def is_twinflow_z_image_model(state_dict):
|
||||||
|
"""Check if model state dict is TwinFlow-Z-Image."""
|
||||||
|
return any(k.startswith("t_embedder_2.") for k in state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def get_twinflow_z_image_config(state_dict):
|
||||||
|
"""Extract TwinFlow-Z-Image configuration from state dict."""
|
||||||
|
if not is_twinflow_z_image_model(state_dict):
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
"image_model": "twinflow_z_image",
|
||||||
|
"architecture": "TwinFlow_Z_Image",
|
||||||
|
}
|
||||||
|
|
||||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||||
key_map = {}
|
key_map = {}
|
||||||
if model is not None:
|
if model is not None:
|
||||||
|
|||||||
@ -1132,6 +1132,15 @@ class ZImagePixelSpace(ZImage):
|
|||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
return model_base.ZImagePixelSpace(self, device=device)
|
return model_base.ZImagePixelSpace(self, device=device)
|
||||||
|
|
||||||
|
class TwinFlow_Z_Image(ZImage):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "twinflow_z_image",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.TwinFlow_Z_Image(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class WAN21_T2V(supported_models_base.BASE):
|
class WAN21_T2V(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
@ -1749,6 +1758,6 @@ class RT_DETR_v4(supported_models_base.BASE):
|
|||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, TwinFlow_Z_Image, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -818,6 +818,17 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
def twinflow_z_image_key_mapping(state_dict, key):
|
||||||
|
"""
|
||||||
|
TwinFlow-Z-Image key mapping.
|
||||||
|
Maps t_embedder_2 keys to t_embedder for weight loading.
|
||||||
|
"""
|
||||||
|
if key.startswith("t_embedder_2."):
|
||||||
|
new_key = key.replace("t_embedder_2.", "t_embedder.", 1)
|
||||||
|
if new_key not in state_dict:
|
||||||
|
state_dict[new_key] = state_dict.pop(key)
|
||||||
|
return state_dic
|
||||||
|
|
||||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||||
if tensor.shape[dim] > batch_size:
|
if tensor.shape[dim] > batch_size:
|
||||||
return tensor.narrow(dim, 0, batch_size)
|
return tensor.narrow(dim, 0, batch_size)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user