mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-18 14:32:49 +08:00
structure model
This commit is contained in:
parent
d6573fd26d
commit
6624939505
@ -4,6 +4,73 @@ from comfy.ldm.modules.attention import optimized_attention
|
|||||||
from typing import Tuple, Union, List
|
from typing import Tuple, Union, List
|
||||||
from vae import VarLenTensor
|
from vae import VarLenTensor
|
||||||
|
|
||||||
|
FLASH_ATTN_3_AVA = True
|
||||||
|
try:
|
||||||
|
import flash_attn_interface as flash_attn_3
|
||||||
|
except:
|
||||||
|
FLASH_ATTN_3_AVA = False
|
||||||
|
|
||||||
|
# TODO repalce with optimized attention
|
||||||
|
def scaled_dot_product_attention(*args, **kwargs):
|
||||||
|
num_all_args = len(args) + len(kwargs)
|
||||||
|
|
||||||
|
if num_all_args == 1:
|
||||||
|
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
||||||
|
|
||||||
|
elif num_all_args == 2:
|
||||||
|
q = args[0] if len(args) > 0 else kwargs['q']
|
||||||
|
kv = args[1] if len(args) > 1 else kwargs['kv']
|
||||||
|
|
||||||
|
elif num_all_args == 3:
|
||||||
|
q = args[0] if len(args) > 0 else kwargs['q']
|
||||||
|
k = args[1] if len(args) > 1 else kwargs['k']
|
||||||
|
v = args[2] if len(args) > 2 else kwargs['v']
|
||||||
|
|
||||||
|
if optimized_attention.__name__ == 'attention_xformers':
|
||||||
|
if 'xops' not in globals():
|
||||||
|
import xformers.ops as xops
|
||||||
|
if num_all_args == 1:
|
||||||
|
q, k, v = qkv.unbind(dim=2)
|
||||||
|
elif num_all_args == 2:
|
||||||
|
k, v = kv.unbind(dim=2)
|
||||||
|
out = xops.memory_efficient_attention(q, k, v)
|
||||||
|
elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA:
|
||||||
|
if 'flash_attn' not in globals():
|
||||||
|
import flash_attn
|
||||||
|
if num_all_args == 2:
|
||||||
|
out = flash_attn.flash_attn_kvpacked_func(q, kv)
|
||||||
|
elif num_all_args == 3:
|
||||||
|
out = flash_attn.flash_attn_func(q, k, v)
|
||||||
|
elif optimized_attention.__name__ == 'attention_flash': # TODO
|
||||||
|
if 'flash_attn_3' not in globals():
|
||||||
|
import flash_attn_interface as flash_attn_3
|
||||||
|
if num_all_args == 2:
|
||||||
|
k, v = kv.unbind(dim=2)
|
||||||
|
out = flash_attn_3.flash_attn_func(q, k, v)
|
||||||
|
elif num_all_args == 3:
|
||||||
|
out = flash_attn_3.flash_attn_func(q, k, v)
|
||||||
|
elif optimized_attention.__name__ == 'attention_pytorch':
|
||||||
|
if 'sdpa' not in globals():
|
||||||
|
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
||||||
|
if num_all_args == 1:
|
||||||
|
q, k, v = qkv.unbind(dim=2)
|
||||||
|
elif num_all_args == 2:
|
||||||
|
k, v = kv.unbind(dim=2)
|
||||||
|
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||||
|
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||||
|
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||||
|
out = sdpa(q, k, v) # [N, H, L, C]
|
||||||
|
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
|
||||||
|
elif optimized_attention.__name__ == 'attention_basic':
|
||||||
|
if num_all_args == 1:
|
||||||
|
q, k, v = qkv.unbind(dim=2)
|
||||||
|
elif num_all_args == 2:
|
||||||
|
k, v = kv.unbind(dim=2)
|
||||||
|
q = q.shape[2] # TODO
|
||||||
|
out = optimized_attention(q, k, v)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
def sparse_windowed_scaled_dot_product_self_attention(
|
def sparse_windowed_scaled_dot_product_self_attention(
|
||||||
qkv,
|
qkv,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
|
|||||||
@ -3,7 +3,9 @@ import torch.nn.functional as F
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
|
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
|
||||||
from typing import Optional, Tuple, Literal, Union, List
|
from typing import Optional, Tuple, Literal, Union, List
|
||||||
from comfy.ldm.trellis2.attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention
|
from comfy.ldm.trellis2.attention import (
|
||||||
|
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
||||||
|
)
|
||||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||||
|
|
||||||
class SparseGELU(nn.GELU):
|
class SparseGELU(nn.GELU):
|
||||||
@ -103,6 +105,18 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
|||||||
k_embed = k.replace(self._rotary_embedding(k.feats, phases))
|
k_embed = k.replace(self._rotary_embedding(k.feats, phases))
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
||||||
|
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
|
||||||
|
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
||||||
|
if phases.shape[-1] < self.head_dim // 2:
|
||||||
|
padn = self.head_dim // 2 - phases.shape[-1]
|
||||||
|
phases = torch.cat([phases, torch.polar(
|
||||||
|
torch.ones(*phases.shape[:-1], padn, device=phases.device),
|
||||||
|
torch.zeros(*phases.shape[:-1], padn, device=phases.device)
|
||||||
|
)], dim=-1)
|
||||||
|
return phases
|
||||||
|
|
||||||
class SparseMultiHeadAttention(nn.Module):
|
class SparseMultiHeadAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -472,6 +486,292 @@ class SLatFlowModel(nn.Module):
|
|||||||
h = self.out_layer(h)
|
h = self.out_layer(h)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
class FeedForwardNet(nn.Module):
|
||||||
|
def __init__(self, channels: int, mlp_ratio: float = 4.0):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(channels, int(channels * mlp_ratio)),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
nn.Linear(int(channels * mlp_ratio), channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.mlp(x)
|
||||||
|
|
||||||
|
class MultiHeadRMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim: int, heads: int):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim ** 0.5
|
||||||
|
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
ctx_channels: Optional[int]=None,
|
||||||
|
type: Literal["self", "cross"] = "self",
|
||||||
|
attn_mode: Literal["full", "windowed"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
use_rope: bool = False,
|
||||||
|
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.channels = channels
|
||||||
|
self.head_dim = channels // num_heads
|
||||||
|
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self._type = type
|
||||||
|
self.attn_mode = attn_mode
|
||||||
|
self.window_size = window_size
|
||||||
|
self.shift_window = shift_window
|
||||||
|
self.use_rope = use_rope
|
||||||
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
|
||||||
|
if self._type == "self":
|
||||||
|
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
|
||||||
|
else:
|
||||||
|
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
|
||||||
|
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
|
||||||
|
|
||||||
|
if self.qk_rms_norm:
|
||||||
|
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
||||||
|
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
||||||
|
|
||||||
|
self.to_out = nn.Linear(channels, channels)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
B, L, C = x.shape
|
||||||
|
if self._type == "self":
|
||||||
|
qkv = self.to_qkv(x)
|
||||||
|
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
|
||||||
|
|
||||||
|
if self.attn_mode == "full":
|
||||||
|
if self.qk_rms_norm or self.use_rope:
|
||||||
|
q, k, v = qkv.unbind(dim=2)
|
||||||
|
if self.qk_rms_norm:
|
||||||
|
q = self.q_rms_norm(q)
|
||||||
|
k = self.k_rms_norm(k)
|
||||||
|
if self.use_rope:
|
||||||
|
assert phases is not None, "Phases must be provided for RoPE"
|
||||||
|
q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
|
||||||
|
k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
|
||||||
|
h = scaled_dot_product_attention(q, k, v)
|
||||||
|
else:
|
||||||
|
h = scaled_dot_product_attention(qkv)
|
||||||
|
else:
|
||||||
|
Lkv = context.shape[1]
|
||||||
|
q = self.to_q(x)
|
||||||
|
kv = self.to_kv(context)
|
||||||
|
q = q.reshape(B, L, self.num_heads, -1)
|
||||||
|
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
||||||
|
if self.qk_rms_norm:
|
||||||
|
q = self.q_rms_norm(q)
|
||||||
|
k, v = kv.unbind(dim=2)
|
||||||
|
k = self.k_rms_norm(k)
|
||||||
|
h = scaled_dot_product_attention(q, k, v)
|
||||||
|
else:
|
||||||
|
h = scaled_dot_product_attention(q, kv)
|
||||||
|
h = h.reshape(B, L, -1)
|
||||||
|
h = self.to_out(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
class ModulatedTransformerCrossBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
ctx_channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "windowed"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_rope: bool = False,
|
||||||
|
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qk_rms_norm_cross: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
share_mod: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.share_mod = share_mod
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.self_attn = MultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="self",
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_window=shift_window,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_rope=use_rope,
|
||||||
|
rope_freq=rope_freq,
|
||||||
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
)
|
||||||
|
self.cross_attn = MultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
ctx_channels=ctx_channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="cross",
|
||||||
|
attn_mode="full",
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_rms_norm=qk_rms_norm_cross,
|
||||||
|
)
|
||||||
|
self.mlp = FeedForwardNet(
|
||||||
|
channels,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
if not share_mod:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(channels, 6 * channels, bias=True)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
|
||||||
|
|
||||||
|
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
if self.share_mod:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||||
|
else:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||||
|
h = self.norm1(x)
|
||||||
|
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
||||||
|
h = self.self_attn(h, phases=phases)
|
||||||
|
h = h * gate_msa.unsqueeze(1)
|
||||||
|
x = x + h
|
||||||
|
h = self.norm2(x)
|
||||||
|
h = self.cross_attn(h, context)
|
||||||
|
x = x + h
|
||||||
|
h = self.norm3(x)
|
||||||
|
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||||
|
h = self.mlp(h)
|
||||||
|
h = h * gate_mlp.unsqueeze(1)
|
||||||
|
x = x + h
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
if self.use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
return self._forward(x, mod, context, phases)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseStructureFlowModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
resolution: int,
|
||||||
|
in_channels: int,
|
||||||
|
model_channels: int,
|
||||||
|
cond_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_blocks: int,
|
||||||
|
num_heads: Optional[int] = None,
|
||||||
|
num_head_channels: Optional[int] = 64,
|
||||||
|
mlp_ratio: float = 4,
|
||||||
|
pe_mode: Literal["ape", "rope"] = "ape",
|
||||||
|
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||||
|
dtype: str = 'float32',
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
share_mod: bool = False,
|
||||||
|
initialization: str = 'vanilla',
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qk_rms_norm_cross: bool = False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.model_channels = model_channels
|
||||||
|
self.cond_channels = cond_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.num_heads = num_heads or model_channels // num_head_channels
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.pe_mode = pe_mode
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.share_mod = share_mod
|
||||||
|
self.initialization = initialization
|
||||||
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
self.t_embedder = TimestepEmbedder(model_channels)
|
||||||
|
if share_mod:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3)
|
||||||
|
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
|
||||||
|
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
||||||
|
rope_phases = pos_embedder(coords)
|
||||||
|
self.register_buffer("rope_phases", rope_phases)
|
||||||
|
|
||||||
|
if pe_mode != "rope":
|
||||||
|
self.rope_phases = None
|
||||||
|
|
||||||
|
self.input_layer = nn.Linear(in_channels, model_channels)
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
ModulatedTransformerCrossBlock(
|
||||||
|
model_channels,
|
||||||
|
cond_channels,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
mlp_ratio=self.mlp_ratio,
|
||||||
|
attn_mode='full',
|
||||||
|
use_checkpoint=self.use_checkpoint,
|
||||||
|
use_rope=(pe_mode == "rope"),
|
||||||
|
rope_freq=rope_freq,
|
||||||
|
share_mod=share_mod,
|
||||||
|
qk_rms_norm=self.qk_rms_norm,
|
||||||
|
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||||
|
)
|
||||||
|
for _ in range(num_blocks)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.out_layer = nn.Linear(model_channels, out_channels)
|
||||||
|
|
||||||
|
self.initialize_weights()
|
||||||
|
self.convert_to(self.dtype)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
|
||||||
|
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
|
||||||
|
|
||||||
|
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
h = self.input_layer(h)
|
||||||
|
if self.pe_mode == "ape":
|
||||||
|
h = h + self.pos_emb[None]
|
||||||
|
t_emb = self.t_embedder(t)
|
||||||
|
if self.share_mod:
|
||||||
|
t_emb = self.adaLN_modulation(t_emb)
|
||||||
|
t_emb = manual_cast(t_emb, self.dtype)
|
||||||
|
h = manual_cast(h, self.dtype)
|
||||||
|
cond = manual_cast(cond, self.dtype)
|
||||||
|
for block in self.blocks:
|
||||||
|
h = block(h, t_emb, cond, self.rope_phases)
|
||||||
|
h = manual_cast(h, x.dtype)
|
||||||
|
h = F.layer_norm(h, h.shape[-1:])
|
||||||
|
h = self.out_layer(h)
|
||||||
|
|
||||||
|
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
class Trellis2(nn.Module):
|
class Trellis2(nn.Module):
|
||||||
def __init__(self, resolution,
|
def __init__(self, resolution,
|
||||||
in_channels = 32,
|
in_channels = 32,
|
||||||
@ -492,18 +792,24 @@ class Trellis2(nn.Module):
|
|||||||
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
||||||
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
||||||
}
|
}
|
||||||
# TODO: update the names/checkpoints
|
|
||||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||||
self.shape_generation = True
|
args.pop("out_channels")
|
||||||
|
args.pop("in_channels")
|
||||||
|
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
# TODO add mode
|
# TODO add mode
|
||||||
mode = kwargs.get("mode", "shape_generation")
|
mode = kwargs.get("mode", "shape_generation")
|
||||||
mode = "texture_generation" if mode == 1 else "shape_generation"
|
if mode != 0:
|
||||||
|
mode = "texture_generation" if mode == 2 else "shape_generation"
|
||||||
|
else:
|
||||||
|
mode = "structure_generation"
|
||||||
if mode == "shape_generation":
|
if mode == "shape_generation":
|
||||||
out = self.img2shape(x, timestep, context)
|
out = self.img2shape(x, timestep, context)
|
||||||
if mode == "texture_generation":
|
elif mode == "texture_generation":
|
||||||
out = self.shape2txt(x, timestep, context)
|
out = self.shape2txt(x, timestep, context)
|
||||||
|
else:
|
||||||
|
out = self.structure_model(x, timestep, context)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -1247,6 +1247,10 @@ class Trellis2(supported_models_base.BASE):
|
|||||||
"image_model": "trellis2"
|
"image_model": "trellis2"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 3.0,
|
||||||
|
}
|
||||||
|
|
||||||
latent_format = latent_formats.Trellis2
|
latent_format = latent_formats.Trellis2
|
||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user