mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
comfy ops + color support in postprocess
This commit is contained in:
parent
ea255543e6
commit
4e14d42da1
@ -14,12 +14,12 @@ class SparseGELU(nn.GELU):
|
||||
return input.replace(super().forward(input.feats))
|
||||
|
||||
class SparseFeedForwardNet(nn.Module):
|
||||
def __init__(self, channels: int, mlp_ratio: float = 4.0):
|
||||
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
SparseLinear(channels, int(channels * mlp_ratio)),
|
||||
SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations),
|
||||
SparseGELU(approximate="tanh"),
|
||||
SparseLinear(int(channels * mlp_ratio), channels),
|
||||
SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations),
|
||||
)
|
||||
|
||||
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
||||
@ -37,10 +37,10 @@ class LayerNorm32(nn.LayerNorm):
|
||||
|
||||
|
||||
class SparseMultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, heads: int):
|
||||
def __init__(self, dim: int, heads: int, device, dtype):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
x_type = x.dtype
|
||||
@ -56,14 +56,15 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
||||
self,
|
||||
head_dim: int,
|
||||
dim: int = 3,
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0)
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.dim = dim
|
||||
self.rope_freq = rope_freq
|
||||
self.freq_dim = head_dim // 2 // dim
|
||||
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
||||
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim
|
||||
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
|
||||
|
||||
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
@ -148,6 +149,7 @@ class SparseMultiHeadAttention(nn.Module):
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -163,19 +165,19 @@ class SparseMultiHeadAttention(nn.Module):
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
if self._type == "self":
|
||||
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
else:
|
||||
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
|
||||
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
|
||||
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
|
||||
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
|
||||
self.to_out = nn.Linear(channels, channels)
|
||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||
|
||||
if use_rope:
|
||||
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq)
|
||||
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device)
|
||||
|
||||
@staticmethod
|
||||
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
@ -267,14 +269,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
qk_rms_norm_cross: bool = False,
|
||||
qkv_bias: bool = True,
|
||||
share_mod: bool = False,
|
||||
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.self_attn = SparseMultiHeadAttention(
|
||||
channels,
|
||||
num_heads=num_heads,
|
||||
@ -286,6 +288,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
use_rope=use_rope,
|
||||
rope_freq=rope_freq,
|
||||
qk_rms_norm=qk_rms_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.cross_attn = SparseMultiHeadAttention(
|
||||
channels,
|
||||
@ -295,18 +298,20 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
attn_mode="full",
|
||||
qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.mlp = SparseFeedForwardNet(
|
||||
channels,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(channels, 6 * channels, bias=True)
|
||||
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
|
||||
if self.share_mod:
|
||||
@ -376,10 +381,10 @@ class SLatFlowModel(nn.Module):
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
||||
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.input_layer = SparseLinear(in_channels, model_channels)
|
||||
self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedSparseTransformerCrossBlock(
|
||||
@ -394,11 +399,12 @@ class SLatFlowModel(nn.Module):
|
||||
share_mod=self.share_mod,
|
||||
qk_rms_norm=self.qk_rms_norm,
|
||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
self.out_layer = SparseLinear(model_channels, out_channels)
|
||||
self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
@ -438,22 +444,22 @@ class SLatFlowModel(nn.Module):
|
||||
return h
|
||||
|
||||
class FeedForwardNet(nn.Module):
|
||||
def __init__(self, channels: int, mlp_ratio: float = 4.0):
|
||||
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(channels, int(channels * mlp_ratio)),
|
||||
operations.Linear(channels, int(channels * mlp_ratio), device=device, dtype=dtype),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(int(channels * mlp_ratio), channels),
|
||||
operations.Linear(int(channels * mlp_ratio), channels, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(x)
|
||||
|
||||
class MultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, heads: int):
|
||||
def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
||||
@ -473,6 +479,7 @@ class MultiHeadAttention(nn.Module):
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -488,16 +495,16 @@ class MultiHeadAttention(nn.Module):
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
if self._type == "self":
|
||||
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
|
||||
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
|
||||
self.to_out = nn.Linear(channels, channels)
|
||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
B, L, C = x.shape
|
||||
@ -554,13 +561,14 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
qk_rms_norm_cross: bool = False,
|
||||
qkv_bias: bool = True,
|
||||
share_mod: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.self_attn = MultiHeadAttention(
|
||||
channels,
|
||||
num_heads=num_heads,
|
||||
@ -572,6 +580,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
use_rope=use_rope,
|
||||
rope_freq=rope_freq,
|
||||
qk_rms_norm=qk_rms_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.cross_attn = MultiHeadAttention(
|
||||
channels,
|
||||
@ -581,18 +590,20 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
attn_mode="full",
|
||||
qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.mlp = FeedForwardNet(
|
||||
channels,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(channels, 6 * channels, bias=True)
|
||||
operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.share_mod:
|
||||
@ -659,16 +670,17 @@ class SparseStructureFlowModel(nn.Module):
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, operations=operations)
|
||||
self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
||||
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3)
|
||||
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij')
|
||||
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device)
|
||||
coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij')
|
||||
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
||||
rope_phases = pos_embedder(coords)
|
||||
self.register_buffer("rope_phases", rope_phases, persistent=False)
|
||||
@ -676,7 +688,7 @@ class SparseStructureFlowModel(nn.Module):
|
||||
if pe_mode != "rope":
|
||||
self.rope_phases = None
|
||||
|
||||
self.input_layer = nn.Linear(in_channels, model_channels)
|
||||
self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedTransformerCrossBlock(
|
||||
@ -691,11 +703,12 @@ class SparseStructureFlowModel(nn.Module):
|
||||
share_mod=share_mod,
|
||||
qk_rms_norm=self.qk_rms_norm,
|
||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
self.out_layer = nn.Linear(model_channels, out_channels)
|
||||
self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
||||
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
|
||||
@ -745,6 +758,7 @@ class Trellis2(nn.Module):
|
||||
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operations = operations or nn
|
||||
# for some reason it passes num_heads = -1
|
||||
if num_heads == -1:
|
||||
num_heads = 12
|
||||
@ -772,6 +786,7 @@ class Trellis2(nn.Module):
|
||||
coords = transformer_options.get("coords", None)
|
||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||
is_512_run = False
|
||||
timestep = timestep.to(self.dtype)
|
||||
if mode == "shape_generation_512":
|
||||
is_512_run = True
|
||||
mode = "shape_generation"
|
||||
|
||||
@ -962,13 +962,17 @@ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
|
||||
feats = input.feats.unbind(dim)
|
||||
return [input.replace(f) for f in feats]
|
||||
|
||||
class SparseLinear(nn.Linear):
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
super(SparseLinear, self).__init__(in_features, out_features, bias)
|
||||
# allow operations.Linear inheritance
|
||||
class SparseLinear:
|
||||
def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs):
|
||||
class _SparseLinear(operations.Linear):
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||
return input.replace(super().forward(input.feats))
|
||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||
return input.replace(super().forward(input.feats))
|
||||
|
||||
return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs)
|
||||
|
||||
MIX_PRECISION_MODULES = (
|
||||
nn.Conv1d,
|
||||
|
||||
@ -481,21 +481,25 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
def simplify_fn(vertices, faces, target=100000):
|
||||
is_batched = vertices.ndim == 3
|
||||
if is_batched:
|
||||
v_list, f_list = [], []
|
||||
def simplify_fn(vertices, faces, colors=None, target=100000):
|
||||
if vertices.ndim == 3:
|
||||
v_list, f_list, c_list = [], [], []
|
||||
for i in range(vertices.shape[0]):
|
||||
v_i, f_i = simplify_fn(vertices[i], faces[i], target)
|
||||
c_in = colors[i] if colors is not None else None
|
||||
v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target)
|
||||
v_list.append(v_i)
|
||||
f_list.append(f_i)
|
||||
return torch.stack(v_list), torch.stack(f_list)
|
||||
if c_i is not None:
|
||||
c_list.append(c_i)
|
||||
|
||||
c_out = torch.stack(c_list) if len(c_list) > 0 else None
|
||||
return torch.stack(v_list), torch.stack(f_list), c_out
|
||||
|
||||
if faces.shape[0] <= target:
|
||||
return vertices, faces
|
||||
return vertices, faces, colors
|
||||
|
||||
device = vertices.device
|
||||
target_v = target / 2.0
|
||||
target_v = max(target / 4.0, 1.0)
|
||||
|
||||
min_v = vertices.min(dim=0)[0]
|
||||
max_v = vertices.max(dim=0)[0]
|
||||
@ -510,14 +514,17 @@ def simplify_fn(vertices, faces, target=100000):
|
||||
|
||||
new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device)
|
||||
counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device)
|
||||
|
||||
new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices)
|
||||
counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1]))
|
||||
|
||||
new_vertices = new_vertices / counts.clamp(min=1)
|
||||
|
||||
new_faces = inverse_indices[faces]
|
||||
new_colors = None
|
||||
if colors is not None:
|
||||
new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device)
|
||||
new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors)
|
||||
new_colors = new_colors / counts.clamp(min=1)
|
||||
|
||||
new_faces = inverse_indices[faces]
|
||||
valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \
|
||||
(new_faces[:, 1] != new_faces[:, 2]) & \
|
||||
(new_faces[:, 2] != new_faces[:, 0])
|
||||
@ -527,7 +534,10 @@ def simplify_fn(vertices, faces, target=100000):
|
||||
final_vertices = new_vertices[unique_face_indices]
|
||||
final_faces = inv_face.reshape(-1, 3)
|
||||
|
||||
return final_vertices, final_faces
|
||||
# assign colors
|
||||
final_colors = new_colors[unique_face_indices] if new_colors is not None else None
|
||||
|
||||
return final_vertices, final_faces, final_colors
|
||||
|
||||
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||
is_batched = vertices.ndim == 3
|
||||
@ -610,19 +620,6 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||
|
||||
return v, f
|
||||
|
||||
def make_double_sided(vertices, faces):
|
||||
is_batched = vertices.ndim == 3
|
||||
if is_batched:
|
||||
f_list =[]
|
||||
for i in range(faces.shape[0]):
|
||||
f_inv = faces[i][:,[0, 2, 1]]
|
||||
f_list.append(torch.cat([faces[i], f_inv], dim=0))
|
||||
return vertices, torch.stack(f_list)
|
||||
|
||||
faces_inv = faces[:, [0, 2, 1]]
|
||||
faces_double = torch.cat([faces, faces_inv], dim=0)
|
||||
return vertices, faces_double
|
||||
|
||||
class PostProcessMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -641,19 +638,23 @@ class PostProcessMesh(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, simplify, fill_holes_perimeter):
|
||||
# TODO: batched mode may break
|
||||
mesh = copy.deepcopy(mesh)
|
||||
verts, faces = mesh.vertices, mesh.faces
|
||||
colors = None
|
||||
if hasattr(mesh, "colors"):
|
||||
colors = mesh.colors
|
||||
|
||||
if fill_holes_perimeter > 0:
|
||||
verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter)
|
||||
|
||||
if simplify > 0 and faces.shape[0] > simplify:
|
||||
verts, faces = simplify_fn(verts, faces, target=simplify)
|
||||
|
||||
verts, faces = make_double_sided(verts, faces)
|
||||
verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors)
|
||||
|
||||
mesh.vertices = verts
|
||||
mesh.faces = faces
|
||||
if colors is not None:
|
||||
mesh.colors = None
|
||||
return IO.NodeOutput(mesh)
|
||||
|
||||
class Trellis2Extension(ComfyExtension):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user