Add ComfyUI operations support and vertex-color support in mesh postprocess

This commit is contained in:
Yousef Rafat 2026-04-10 16:12:23 +02:00
parent ea255543e6
commit 4e14d42da1
3 changed files with 98 additions and 78 deletions

View File

@ -14,12 +14,12 @@ class SparseGELU(nn.GELU):
return input.replace(super().forward(input.feats))
class SparseFeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0):
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
SparseLinear(channels, int(channels * mlp_ratio)),
SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations),
SparseGELU(approximate="tanh"),
SparseLinear(int(channels * mlp_ratio), channels),
SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations),
)
def forward(self, x: VarLenTensor) -> VarLenTensor:
@ -37,10 +37,10 @@ class LayerNorm32(nn.LayerNorm):
class SparseMultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int):
def __init__(self, dim: int, heads: int, device, dtype):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, dim))
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
x_type = x.dtype
@ -56,14 +56,15 @@ class SparseRotaryPositionEmbedder(nn.Module):
self,
head_dim: int,
dim: int = 3,
rope_freq: Tuple[float, float] = (1.0, 10000.0)
rope_freq: Tuple[float, float] = (1.0, 10000.0),
device=None
):
super().__init__()
self.head_dim = head_dim
self.dim = dim
self.rope_freq = rope_freq
self.freq_dim = head_dim // 2 // dim
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
@ -148,6 +149,7 @@ class SparseMultiHeadAttention(nn.Module):
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False,
device=None, dtype=None, operations=None
):
super().__init__()
@ -163,19 +165,19 @@ class SparseMultiHeadAttention(nn.Module):
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype)
else:
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
if self.qk_rms_norm:
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.to_out = nn.Linear(channels, channels)
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
if use_rope:
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq)
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device)
@staticmethod
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
@ -267,14 +269,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
device=None, dtype=None, operations=None
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.self_attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
@ -286,6 +288,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
device=device, dtype=dtype, operations=operations
)
self.cross_attn = SparseMultiHeadAttention(
channels,
@ -295,18 +298,20 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
device=device, dtype=dtype, operations=operations
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 6 * channels, bias=True)
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
if self.share_mod:
@ -376,10 +381,10 @@ class SLatFlowModel(nn.Module):
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(model_channels, 6 * model_channels, bias=True)
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
)
self.input_layer = SparseLinear(in_channels, model_channels)
self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations)
self.blocks = nn.ModuleList([
ModulatedSparseTransformerCrossBlock(
@ -394,11 +399,12 @@ class SLatFlowModel(nn.Module):
share_mod=self.share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_blocks)
])
self.out_layer = SparseLinear(model_channels, out_channels)
self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations)
@property
def device(self) -> torch.device:
@ -438,22 +444,22 @@ class SLatFlowModel(nn.Module):
return h
class FeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0):
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(channels, int(channels * mlp_ratio)),
operations.Linear(channels, int(channels * mlp_ratio), device=device, dtype=dtype),
nn.GELU(approximate="tanh"),
nn.Linear(int(channels * mlp_ratio), channels),
operations.Linear(int(channels * mlp_ratio), channels, device=device, dtype=dtype),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
class MultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int):
def __init__(self, dim: int, heads: int, device=None, dtype=None):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, dim))
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
@ -473,6 +479,7 @@ class MultiHeadAttention(nn.Module):
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False,
device=None, dtype=None, operations=None
):
super().__init__()
@ -488,16 +495,16 @@ class MultiHeadAttention(nn.Module):
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
else:
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
if self.qk_rms_norm:
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.to_out = nn.Linear(channels, channels)
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
B, L, C = x.shape
@ -554,13 +561,14 @@ class ModulatedTransformerCrossBlock(nn.Module):
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
device=None, dtype=None, operations=None
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.self_attn = MultiHeadAttention(
channels,
num_heads=num_heads,
@ -572,6 +580,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
device=device, dtype=dtype, operations=operations
)
self.cross_attn = MultiHeadAttention(
channels,
@ -581,18 +590,20 @@ class ModulatedTransformerCrossBlock(nn.Module):
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
device=device, dtype=dtype, operations=operations
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 6 * channels, bias=True)
operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.share_mod:
@ -659,16 +670,17 @@ class SparseStructureFlowModel(nn.Module):
self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross
self.dtype = dtype
self.device = device
self.t_embedder = TimestepEmbedder(model_channels, operations=operations)
self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(model_channels, 6 * model_channels, bias=True)
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
)
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3)
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device)
coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij')
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
rope_phases = pos_embedder(coords)
self.register_buffer("rope_phases", rope_phases, persistent=False)
@ -676,7 +688,7 @@ class SparseStructureFlowModel(nn.Module):
if pe_mode != "rope":
self.rope_phases = None
self.input_layer = nn.Linear(in_channels, model_channels)
self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype)
self.blocks = nn.ModuleList([
ModulatedTransformerCrossBlock(
@ -691,11 +703,12 @@ class SparseStructureFlowModel(nn.Module):
share_mod=share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_blocks)
])
self.out_layer = nn.Linear(model_channels, out_channels)
self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
@ -745,6 +758,7 @@ class Trellis2(nn.Module):
super().__init__()
self.dtype = dtype
operations = operations or nn
# for some reason it passes num_heads = -1
if num_heads == -1:
num_heads = 12
@ -772,6 +786,7 @@ class Trellis2(nn.Module):
coords = transformer_options.get("coords", None)
mode = transformer_options.get("generation_mode", "structure_generation")
is_512_run = False
timestep = timestep.to(self.dtype)
if mode == "shape_generation_512":
is_512_run = True
mode = "shape_generation"

View File

@ -962,13 +962,17 @@ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
feats = input.feats.unbind(dim)
return [input.replace(f) for f in feats]
class SparseLinear(nn.Linear):
def __init__(self, in_features, out_features, bias=True):
super(SparseLinear, self).__init__(in_features, out_features, bias)
# allow operations.Linear inheritance
class SparseLinear:
def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs):
class _SparseLinear(operations.Linear):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
def forward(self, input: VarLenTensor) -> VarLenTensor:
return input.replace(super().forward(input.feats))
def forward(self, input: VarLenTensor) -> VarLenTensor:
return input.replace(super().forward(input.feats))
return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs)
MIX_PRECISION_MODULES = (
nn.Conv1d,

View File

@ -481,21 +481,25 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
def simplify_fn(vertices, faces, target=100000):
is_batched = vertices.ndim == 3
if is_batched:
v_list, f_list = [], []
def simplify_fn(vertices, faces, colors=None, target=100000):
if vertices.ndim == 3:
v_list, f_list, c_list = [], [], []
for i in range(vertices.shape[0]):
v_i, f_i = simplify_fn(vertices[i], faces[i], target)
c_in = colors[i] if colors is not None else None
v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target)
v_list.append(v_i)
f_list.append(f_i)
return torch.stack(v_list), torch.stack(f_list)
if c_i is not None:
c_list.append(c_i)
c_out = torch.stack(c_list) if len(c_list) > 0 else None
return torch.stack(v_list), torch.stack(f_list), c_out
if faces.shape[0] <= target:
return vertices, faces
return vertices, faces, colors
device = vertices.device
target_v = target / 2.0
target_v = max(target / 4.0, 1.0)
min_v = vertices.min(dim=0)[0]
max_v = vertices.max(dim=0)[0]
@ -510,14 +514,17 @@ def simplify_fn(vertices, faces, target=100000):
new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device)
counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device)
new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices)
counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1]))
new_vertices = new_vertices / counts.clamp(min=1)
new_faces = inverse_indices[faces]
new_colors = None
if colors is not None:
new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device)
new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors)
new_colors = new_colors / counts.clamp(min=1)
new_faces = inverse_indices[faces]
valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \
(new_faces[:, 1] != new_faces[:, 2]) & \
(new_faces[:, 2] != new_faces[:, 0])
@ -527,7 +534,10 @@ def simplify_fn(vertices, faces, target=100000):
final_vertices = new_vertices[unique_face_indices]
final_faces = inv_face.reshape(-1, 3)
return final_vertices, final_faces
# assign colors
final_colors = new_colors[unique_face_indices] if new_colors is not None else None
return final_vertices, final_faces, final_colors
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
is_batched = vertices.ndim == 3
@ -610,19 +620,6 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
return v, f
def make_double_sided(vertices, faces):
is_batched = vertices.ndim == 3
if is_batched:
f_list =[]
for i in range(faces.shape[0]):
f_inv = faces[i][:,[0, 2, 1]]
f_list.append(torch.cat([faces[i], f_inv], dim=0))
return vertices, torch.stack(f_list)
faces_inv = faces[:, [0, 2, 1]]
faces_double = torch.cat([faces, faces_inv], dim=0)
return vertices, faces_double
class PostProcessMesh(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -641,19 +638,23 @@ class PostProcessMesh(IO.ComfyNode):
@classmethod
def execute(cls, mesh, simplify, fill_holes_perimeter):
# TODO: batched mode may break
mesh = copy.deepcopy(mesh)
verts, faces = mesh.vertices, mesh.faces
colors = None
if hasattr(mesh, "colors"):
colors = mesh.colors
if fill_holes_perimeter > 0:
verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter)
if simplify > 0 and faces.shape[0] > simplify:
verts, faces = simplify_fn(verts, faces, target=simplify)
verts, faces = make_double_sided(verts, faces)
verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors)
mesh.vertices = verts
mesh.faces = faces
if colors is not None:
mesh.colors = None
return IO.NodeOutput(mesh)
class Trellis2Extension(ComfyExtension):