mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
comfy ops + color support in postprocess
This commit is contained in:
parent
ea255543e6
commit
4e14d42da1
@ -14,12 +14,12 @@ class SparseGELU(nn.GELU):
|
|||||||
return input.replace(super().forward(input.feats))
|
return input.replace(super().forward(input.feats))
|
||||||
|
|
||||||
class SparseFeedForwardNet(nn.Module):
|
class SparseFeedForwardNet(nn.Module):
|
||||||
def __init__(self, channels: int, mlp_ratio: float = 4.0):
|
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mlp = nn.Sequential(
|
self.mlp = nn.Sequential(
|
||||||
SparseLinear(channels, int(channels * mlp_ratio)),
|
SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations),
|
||||||
SparseGELU(approximate="tanh"),
|
SparseGELU(approximate="tanh"),
|
||||||
SparseLinear(int(channels * mlp_ratio), channels),
|
SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
||||||
@ -37,10 +37,10 @@ class LayerNorm32(nn.LayerNorm):
|
|||||||
|
|
||||||
|
|
||||||
class SparseMultiHeadRMSNorm(nn.Module):
|
class SparseMultiHeadRMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, heads: int):
|
def __init__(self, dim: int, heads: int, device, dtype):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = dim ** 0.5
|
self.scale = dim ** 0.5
|
||||||
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||||
x_type = x.dtype
|
x_type = x.dtype
|
||||||
@ -56,14 +56,15 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
head_dim: int,
|
head_dim: int,
|
||||||
dim: int = 3,
|
dim: int = 3,
|
||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0)
|
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||||
|
device=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.rope_freq = rope_freq
|
self.rope_freq = rope_freq
|
||||||
self.freq_dim = head_dim // 2 // dim
|
self.freq_dim = head_dim // 2 // dim
|
||||||
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim
|
||||||
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
|
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
|
||||||
|
|
||||||
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
|
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
|
||||||
@ -148,6 +149,7 @@ class SparseMultiHeadAttention(nn.Module):
|
|||||||
use_rope: bool = False,
|
use_rope: bool = False,
|
||||||
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
||||||
qk_rms_norm: bool = False,
|
qk_rms_norm: bool = False,
|
||||||
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -163,19 +165,19 @@ class SparseMultiHeadAttention(nn.Module):
|
|||||||
self.qk_rms_norm = qk_rms_norm
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
|
||||||
if self._type == "self":
|
if self._type == "self":
|
||||||
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
|
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
|
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
|
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
|
||||||
if self.qk_rms_norm:
|
if self.qk_rms_norm:
|
||||||
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
|
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||||
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
|
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.to_out = nn.Linear(channels, channels)
|
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||||
|
|
||||||
if use_rope:
|
if use_rope:
|
||||||
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq)
|
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||||
@ -267,14 +269,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
|||||||
qk_rms_norm_cross: bool = False,
|
qk_rms_norm_cross: bool = False,
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
share_mod: bool = False,
|
share_mod: bool = False,
|
||||||
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
self.share_mod = share_mod
|
self.share_mod = share_mod
|
||||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||||
self.self_attn = SparseMultiHeadAttention(
|
self.self_attn = SparseMultiHeadAttention(
|
||||||
channels,
|
channels,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
@ -286,6 +288,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
|||||||
use_rope=use_rope,
|
use_rope=use_rope,
|
||||||
rope_freq=rope_freq,
|
rope_freq=rope_freq,
|
||||||
qk_rms_norm=qk_rms_norm,
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
self.cross_attn = SparseMultiHeadAttention(
|
self.cross_attn = SparseMultiHeadAttention(
|
||||||
channels,
|
channels,
|
||||||
@ -295,18 +298,20 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
|||||||
attn_mode="full",
|
attn_mode="full",
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
qk_rms_norm=qk_rms_norm_cross,
|
qk_rms_norm=qk_rms_norm_cross,
|
||||||
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
self.mlp = SparseFeedForwardNet(
|
self.mlp = SparseFeedForwardNet(
|
||||||
channels,
|
channels,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
if not share_mod:
|
if not share_mod:
|
||||||
self.adaLN_modulation = nn.Sequential(
|
self.adaLN_modulation = nn.Sequential(
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(channels, 6 * channels, bias=True)
|
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
|
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||||
|
|
||||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
|
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
|
||||||
if self.share_mod:
|
if self.share_mod:
|
||||||
@ -376,10 +381,10 @@ class SLatFlowModel(nn.Module):
|
|||||||
if share_mod:
|
if share_mod:
|
||||||
self.adaLN_modulation = nn.Sequential(
|
self.adaLN_modulation = nn.Sequential(
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.input_layer = SparseLinear(in_channels, model_channels)
|
self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations)
|
||||||
|
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
ModulatedSparseTransformerCrossBlock(
|
ModulatedSparseTransformerCrossBlock(
|
||||||
@ -394,11 +399,12 @@ class SLatFlowModel(nn.Module):
|
|||||||
share_mod=self.share_mod,
|
share_mod=self.share_mod,
|
||||||
qk_rms_norm=self.qk_rms_norm,
|
qk_rms_norm=self.qk_rms_norm,
|
||||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||||
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
for _ in range(num_blocks)
|
for _ in range(num_blocks)
|
||||||
])
|
])
|
||||||
|
|
||||||
self.out_layer = SparseLinear(model_channels, out_channels)
|
self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
@ -438,22 +444,22 @@ class SLatFlowModel(nn.Module):
|
|||||||
return h
|
return h
|
||||||
|
|
||||||
class FeedForwardNet(nn.Module):
    """Two-layer MLP: Linear -> tanh-approximated GELU -> Linear.

    Args:
        channels: Input and output feature size.
        mlp_ratio: Hidden width multiplier; the hidden size is
            ``int(channels * mlp_ratio)``.
        device: Optional device passed to both linear layers.
        dtype: Optional dtype passed to both linear layers.
        operations: Namespace providing a ``Linear`` class. When ``None``
            the stock ``torch.nn`` layers are used.
    """

    def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
        super().__init__()
        # ``operations`` defaults to None; without this fallback a default
        # construction would fail on ``None.Linear``.
        if operations is None:
            operations = nn
        hidden_channels = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            operations.Linear(channels, hidden_channels, device=device, dtype=dtype),
            nn.GELU(approximate="tanh"),
            operations.Linear(hidden_channels, channels, device=device, dtype=dtype),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x`` and return the result."""
        return self.mlp(x)
|
||||||
|
|
||||||
class MultiHeadRMSNorm(nn.Module):
    """Normalise the last axis to unit length, then apply a per-head gain.

    ``gamma`` is a learnable tensor of shape ``(heads, dim)`` initialised to
    ones; ``scale`` is the constant ``dim ** 0.5``.
    """

    def __init__(self, dim: int, heads: int, device=None, dtype=None):
        super().__init__()
        self.scale = dim ** 0.5
        initial_gain = torch.ones(heads, dim, device=device, dtype=dtype)
        self.gamma = nn.Parameter(initial_gain)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalised, gained input cast back to ``x``'s dtype."""
        input_dtype = x.dtype
        # The normalisation itself runs in float32.
        unit = F.normalize(x.float(), dim=-1)
        gained = unit * self.gamma * self.scale
        return gained.to(input_dtype)
|
||||||
@ -473,6 +479,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
use_rope: bool = False,
|
use_rope: bool = False,
|
||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||||
qk_rms_norm: bool = False,
|
qk_rms_norm: bool = False,
|
||||||
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -488,16 +495,16 @@ class MultiHeadAttention(nn.Module):
|
|||||||
self.qk_rms_norm = qk_rms_norm
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
|
||||||
if self._type == "self":
|
if self._type == "self":
|
||||||
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
|
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
|
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
|
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
|
||||||
if self.qk_rms_norm:
|
if self.qk_rms_norm:
|
||||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.to_out = nn.Linear(channels, channels)
|
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
B, L, C = x.shape
|
B, L, C = x.shape
|
||||||
@ -554,13 +561,14 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
|||||||
qk_rms_norm_cross: bool = False,
|
qk_rms_norm_cross: bool = False,
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
share_mod: bool = False,
|
share_mod: bool = False,
|
||||||
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
self.share_mod = share_mod
|
self.share_mod = share_mod
|
||||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||||
self.self_attn = MultiHeadAttention(
|
self.self_attn = MultiHeadAttention(
|
||||||
channels,
|
channels,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
@ -572,6 +580,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
|||||||
use_rope=use_rope,
|
use_rope=use_rope,
|
||||||
rope_freq=rope_freq,
|
rope_freq=rope_freq,
|
||||||
qk_rms_norm=qk_rms_norm,
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
self.cross_attn = MultiHeadAttention(
|
self.cross_attn = MultiHeadAttention(
|
||||||
channels,
|
channels,
|
||||||
@ -581,18 +590,20 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
|||||||
attn_mode="full",
|
attn_mode="full",
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
qk_rms_norm=qk_rms_norm_cross,
|
qk_rms_norm=qk_rms_norm_cross,
|
||||||
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
self.mlp = FeedForwardNet(
|
self.mlp = FeedForwardNet(
|
||||||
channels,
|
channels,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
if not share_mod:
|
if not share_mod:
|
||||||
self.adaLN_modulation = nn.Sequential(
|
self.adaLN_modulation = nn.Sequential(
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(channels, 6 * channels, bias=True)
|
operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
|
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||||
|
|
||||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
if self.share_mod:
|
if self.share_mod:
|
||||||
@ -659,16 +670,17 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
self.qk_rms_norm = qk_rms_norm
|
self.qk_rms_norm = qk_rms_norm
|
||||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
|
||||||
self.t_embedder = TimestepEmbedder(model_channels, operations=operations)
|
self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
|
||||||
if share_mod:
|
if share_mod:
|
||||||
self.adaLN_modulation = nn.Sequential(
|
self.adaLN_modulation = nn.Sequential(
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3)
|
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device)
|
||||||
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
|
coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij')
|
||||||
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
||||||
rope_phases = pos_embedder(coords)
|
rope_phases = pos_embedder(coords)
|
||||||
self.register_buffer("rope_phases", rope_phases, persistent=False)
|
self.register_buffer("rope_phases", rope_phases, persistent=False)
|
||||||
@ -676,7 +688,7 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
if pe_mode != "rope":
|
if pe_mode != "rope":
|
||||||
self.rope_phases = None
|
self.rope_phases = None
|
||||||
|
|
||||||
self.input_layer = nn.Linear(in_channels, model_channels)
|
self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
ModulatedTransformerCrossBlock(
|
ModulatedTransformerCrossBlock(
|
||||||
@ -691,11 +703,12 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
share_mod=share_mod,
|
share_mod=share_mod,
|
||||||
qk_rms_norm=self.qk_rms_norm,
|
qk_rms_norm=self.qk_rms_norm,
|
||||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||||
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
for _ in range(num_blocks)
|
for _ in range(num_blocks)
|
||||||
])
|
])
|
||||||
|
|
||||||
self.out_layer = nn.Linear(model_channels, out_channels)
|
self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
||||||
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
|
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
|
||||||
@ -745,6 +758,7 @@ class Trellis2(nn.Module):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
operations = operations or nn
|
||||||
# for some reason it passes num_heads = -1
|
# for some reason it passes num_heads = -1
|
||||||
if num_heads == -1:
|
if num_heads == -1:
|
||||||
num_heads = 12
|
num_heads = 12
|
||||||
@ -772,6 +786,7 @@ class Trellis2(nn.Module):
|
|||||||
coords = transformer_options.get("coords", None)
|
coords = transformer_options.get("coords", None)
|
||||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||||
is_512_run = False
|
is_512_run = False
|
||||||
|
timestep = timestep.to(self.dtype)
|
||||||
if mode == "shape_generation_512":
|
if mode == "shape_generation_512":
|
||||||
is_512_run = True
|
is_512_run = True
|
||||||
mode = "shape_generation"
|
mode = "shape_generation"
|
||||||
|
|||||||
@ -962,13 +962,17 @@ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
|
|||||||
feats = input.feats.unbind(dim)
|
feats = input.feats.unbind(dim)
|
||||||
return [input.replace(f) for f in feats]
|
return [input.replace(f) for f in feats]
|
||||||
|
|
||||||
# allow operations.Linear inheritance
class SparseLinear:
    """Factory for a linear layer that operates on an object's ``feats``.

    Calling ``SparseLinear(...)`` does not return a ``SparseLinear`` instance.
    It builds a subclass of ``operations.Linear`` and returns an instance of
    that subclass, so the layer inherits whatever ``Linear`` implementation
    the supplied ``operations`` namespace provides.

    Args:
        in_features: Input feature size of the linear layer.
        out_features: Output feature size of the linear layer.
        bias: Whether the layer has a bias term.
        device: Optional device for the layer's parameters.
        dtype: Optional dtype for the layer's parameters.
        operations: Namespace exposing ``Linear``. ``None`` is treated the
            same as the default ``torch.nn``.
        *args, **kwargs: Forwarded unchanged to the generated class.
    """

    def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs):
        # Callers forward their own ``operations`` argument, whose default is
        # None; fall back to torch.nn instead of failing on ``None.Linear``.
        if operations is None:
            operations = nn

        class _SparseLinear(operations.Linear):
            def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
                super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

            # Annotations are quoted so building this class does not require
            # VarLenTensor to be resolvable at call time.
            def forward(self, input: "VarLenTensor") -> "VarLenTensor":
                # Run the dense linear on the feature matrix and wrap the
                # result back into the same container type via ``replace``.
                return input.replace(super().forward(input.feats))

        return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs)
|
||||||
|
|
||||||
MIX_PRECISION_MODULES = (
|
MIX_PRECISION_MODULES = (
|
||||||
nn.Conv1d,
|
nn.Conv1d,
|
||||||
|
|||||||
@ -481,21 +481,25 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
def simplify_fn(vertices, faces, target=100000):
|
def simplify_fn(vertices, faces, colors=None, target=100000):
|
||||||
is_batched = vertices.ndim == 3
|
if vertices.ndim == 3:
|
||||||
if is_batched:
|
v_list, f_list, c_list = [], [], []
|
||||||
v_list, f_list = [], []
|
|
||||||
for i in range(vertices.shape[0]):
|
for i in range(vertices.shape[0]):
|
||||||
v_i, f_i = simplify_fn(vertices[i], faces[i], target)
|
c_in = colors[i] if colors is not None else None
|
||||||
|
v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target)
|
||||||
v_list.append(v_i)
|
v_list.append(v_i)
|
||||||
f_list.append(f_i)
|
f_list.append(f_i)
|
||||||
return torch.stack(v_list), torch.stack(f_list)
|
if c_i is not None:
|
||||||
|
c_list.append(c_i)
|
||||||
|
|
||||||
|
c_out = torch.stack(c_list) if len(c_list) > 0 else None
|
||||||
|
return torch.stack(v_list), torch.stack(f_list), c_out
|
||||||
|
|
||||||
if faces.shape[0] <= target:
|
if faces.shape[0] <= target:
|
||||||
return vertices, faces
|
return vertices, faces, colors
|
||||||
|
|
||||||
device = vertices.device
|
device = vertices.device
|
||||||
target_v = target / 2.0
|
target_v = max(target / 4.0, 1.0)
|
||||||
|
|
||||||
min_v = vertices.min(dim=0)[0]
|
min_v = vertices.min(dim=0)[0]
|
||||||
max_v = vertices.max(dim=0)[0]
|
max_v = vertices.max(dim=0)[0]
|
||||||
@ -510,14 +514,17 @@ def simplify_fn(vertices, faces, target=100000):
|
|||||||
|
|
||||||
new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device)
|
new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device)
|
||||||
counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device)
|
counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device)
|
||||||
|
|
||||||
new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices)
|
new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices)
|
||||||
counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1]))
|
counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1]))
|
||||||
|
|
||||||
new_vertices = new_vertices / counts.clamp(min=1)
|
new_vertices = new_vertices / counts.clamp(min=1)
|
||||||
|
|
||||||
new_faces = inverse_indices[faces]
|
new_colors = None
|
||||||
|
if colors is not None:
|
||||||
|
new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device)
|
||||||
|
new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors)
|
||||||
|
new_colors = new_colors / counts.clamp(min=1)
|
||||||
|
|
||||||
|
new_faces = inverse_indices[faces]
|
||||||
valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \
|
valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \
|
||||||
(new_faces[:, 1] != new_faces[:, 2]) & \
|
(new_faces[:, 1] != new_faces[:, 2]) & \
|
||||||
(new_faces[:, 2] != new_faces[:, 0])
|
(new_faces[:, 2] != new_faces[:, 0])
|
||||||
@ -527,7 +534,10 @@ def simplify_fn(vertices, faces, target=100000):
|
|||||||
final_vertices = new_vertices[unique_face_indices]
|
final_vertices = new_vertices[unique_face_indices]
|
||||||
final_faces = inv_face.reshape(-1, 3)
|
final_faces = inv_face.reshape(-1, 3)
|
||||||
|
|
||||||
return final_vertices, final_faces
|
# assign colors
|
||||||
|
final_colors = new_colors[unique_face_indices] if new_colors is not None else None
|
||||||
|
|
||||||
|
return final_vertices, final_faces, final_colors
|
||||||
|
|
||||||
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||||
is_batched = vertices.ndim == 3
|
is_batched = vertices.ndim == 3
|
||||||
@ -610,19 +620,6 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
|||||||
|
|
||||||
return v, f
|
return v, f
|
||||||
|
|
||||||
def make_double_sided(vertices, faces):
|
|
||||||
is_batched = vertices.ndim == 3
|
|
||||||
if is_batched:
|
|
||||||
f_list =[]
|
|
||||||
for i in range(faces.shape[0]):
|
|
||||||
f_inv = faces[i][:,[0, 2, 1]]
|
|
||||||
f_list.append(torch.cat([faces[i], f_inv], dim=0))
|
|
||||||
return vertices, torch.stack(f_list)
|
|
||||||
|
|
||||||
faces_inv = faces[:, [0, 2, 1]]
|
|
||||||
faces_double = torch.cat([faces, faces_inv], dim=0)
|
|
||||||
return vertices, faces_double
|
|
||||||
|
|
||||||
class PostProcessMesh(IO.ComfyNode):
|
class PostProcessMesh(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -641,19 +638,23 @@ class PostProcessMesh(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, mesh, simplify, fill_holes_perimeter):
    """Post-process a mesh: optional hole filling, then optional simplification.

    Args:
        mesh: Object exposing ``vertices`` and ``faces`` and, optionally,
            ``colors``. It is deep-copied, so the input is not modified.
        simplify: Target face count. Simplification runs only when this is
            positive and ``faces.shape[0]`` exceeds it.
        fill_holes_perimeter: Maximum hole perimeter handed to
            ``fill_holes_fn``; hole filling is skipped when it is not
            positive.

    Returns:
        ``IO.NodeOutput`` wrapping the processed copy of the mesh.
    """
    # TODO: batched mode may break
    mesh = copy.deepcopy(mesh)
    verts, faces = mesh.vertices, mesh.faces
    # Colors are optional on the mesh object.
    colors = getattr(mesh, "colors", None)

    if fill_holes_perimeter > 0:
        # NOTE(review): fill_holes_fn returns only vertices and faces. If it
        # adds or reorders vertices, ``colors`` no longer lines up with
        # ``verts`` - confirm against fill_holes_fn.
        verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter)

    if simplify > 0 and faces.shape[0] > simplify:
        verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors)

    mesh.vertices = verts
    mesh.faces = faces
    if colors is not None:
        # Store the colors that match the processed vertices; assigning None
        # here would throw away the result computed by simplify_fn.
        mesh.colors = colors
    return IO.NodeOutput(mesh)
|
||||||
|
|
||||||
class Trellis2Extension(ComfyExtension):
|
class Trellis2Extension(ComfyExtension):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user