needed updates

This commit is contained in:
Yousef Rafat 2026-02-04 14:15:00 +02:00
parent 6624939505
commit 3002708fe3
4 changed files with 209 additions and 19 deletions

View File

@ -7,6 +7,7 @@ from comfy.ldm.trellis2.attention import (
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
)
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
from comfy.nested_tensor import NestedTensor
class SparseGELU(nn.GELU):
def forward(self, input: VarLenTensor) -> VarLenTensor:
@ -772,6 +773,11 @@ class SparseStructureFlowModel(nn.Module):
return h
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
return t_new
class Trellis2(nn.Module):
def __init__(self, resolution,
in_channels = 32,
@ -798,18 +804,25 @@ class Trellis2(nn.Module):
args.pop("in_channels")
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
def forward(self, x, timestep, context, **kwargs):
# TODO add mode
mode = kwargs.get("mode", "shape_generation")
if mode != 0:
mode = "texture_generation" if mode == 2 else "shape_generation"
else:
def forward(self, x: NestedTensor, timestep, context, **kwargs):
x = x.tensors[0]
embeds = kwargs.get("embeds")
if not hasattr(x, "feats"):
mode = "structure_generation"
else:
if x.feats.shape[1] == 32:
mode = "shape_generation"
else:
mode = "texture_generation"
if mode == "shape_generation":
out = self.img2shape(x, timestep, context)
# TODO
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
elif mode == "texture_generation":
out = self.shape2txt(x, timestep, context)
else:
else: # structure
timestep = timestep_reshift(timestep)
out = self.structure_model(x, timestep, context)
out = NestedTensor([out])
out.generation_mode = mode
return out

View File

@ -9,6 +9,17 @@ import numpy as np
from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
"""
3D pixel shuffle.
"""
B, C, H, W, D = x.shape
C_ = C // scale_factor**3
x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
return x
class SparseConv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
super(SparseConv3d, self).__init__()
@ -1337,6 +1348,135 @@ def flexible_dual_grid_to_mesh(
return mesh_vertices, mesh_triangles
class ChannelLayerNorm32(LayerNorm32):
def forward(self, x: torch.Tensor) -> torch.Tensor:
DIM = x.dim()
x = x.permute(0, *range(2, DIM), 1).contiguous()
x = super().forward(x)
x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
return x
class UpsampleBlock3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
mode = "conv",
):
assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if mode == "conv":
self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
elif mode == "nearest":
assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self, "conv"):
x = self.conv(x)
return pixel_shuffle_3d(x, 2)
else:
return F.interpolate(x, scale_factor=2, mode="nearest")
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
return ChannelLayerNorm32(*args, **kwargs)
class ResBlock3d(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
norm_type = "layer",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.norm1 = norm_layer(norm_type, channels)
self.norm2 = norm_layer(norm_type, self.out_channels)
self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.norm1(x)
h = F.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = F.silu(h)
h = self.conv2(h)
h = h + self.skip_connection(x)
return h
class SparseStructureDecoder(nn.Module):
def __init__(
self,
out_channels: int,
latent_channels: int,
num_res_blocks: int,
channels: List[int],
num_res_blocks_middle: int = 2,
norm_type = "layer",
use_fp16: bool = False,
):
super().__init__()
self.out_channels = out_channels
self.latent_channels = latent_channels
self.num_res_blocks = num_res_blocks
self.channels = channels
self.num_res_blocks_middle = num_res_blocks_middle
self.norm_type = norm_type
self.use_fp16 = use_fp16
self.dtype = torch.float16 if use_fp16 else torch.float32
self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
self.middle_block = nn.Sequential(*[
ResBlock3d(channels[0], channels[0])
for _ in range(num_res_blocks_middle)
])
self.blocks = nn.ModuleList([])
for i, ch in enumerate(channels):
self.blocks.extend([
ResBlock3d(ch, ch)
for _ in range(num_res_blocks)
])
if i < len(channels) - 1:
self.blocks.append(
UpsampleBlock3d(ch, channels[i+1])
)
self.out_layer = nn.Sequential(
norm_layer(norm_type, channels[-1]),
nn.SiLU(),
nn.Conv3d(channels[-1], out_channels, 3, padding=1)
)
if use_fp16:
self.convert_to_fp16()
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.input_layer(x)
h = h.type(self.dtype)
h = self.middle_block(h)
for block in self.blocks:
h = block(h)
h = h.type(x.dtype)
h = self.out_layer(h)
return h
class Vae(nn.Module):
def __init__(self, config, operations=None):
super().__init__()
@ -1363,6 +1503,14 @@ class Vae(nn.Module):
block_args=[{}, {}, {}, {}, {}],
)
self.struct_dec = SparseStructureDecoder(
out_channels=1,
latent_channels=8,
num_res_blocks=2,
num_res_blocks_middle=2,
channels=[512, 128, 32],
)
def decode_shape_slat(self, slat, resolution: int):
self.shape_dec.set_resolution(resolution)
return self.shape_dec(slat, return_subs=True)

View File

@ -1461,7 +1461,10 @@ class Trellis2(BaseModel):
super().__init__(model_config, model_type, device, unet_model)
def extra_conds(self, **kwargs):
return super().extra_conds(**kwargs)
out = super().extra_conds(**kwargs)
embeds = kwargs.get("embeds")
out["embeds"] = comfy.conds.CONDRegular(embeds)
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

View File

@ -6,6 +6,7 @@ import comfy.model_management
from PIL import Image
import PIL
import numpy as np
from comfy.nested_tensor import NestedTensor
shape_slat_normalization = {
"mean": torch.tensor([
@ -131,7 +132,8 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
)
@classmethod
def execute(cls, samples, vae, resolution):
def execute(cls, samples: NestedTensor, vae, resolution):
samples = samples.tensors[0]
std = shape_slat_normalization["std"]
mean = shape_slat_normalization["mean"]
samples = samples * std + mean
@ -157,9 +159,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
@classmethod
def execute(cls, samples, vae, shape_subs):
if shape_subs is None:
raise ValueError("Shape subs must be provided for texture generation")
samples = samples.tensors[0]
std = tex_slat_normalization["std"]
mean = tex_slat_normalization["mean"]
samples = samples * std + mean
@ -167,6 +167,28 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
mesh = vae.decode_tex_slat(samples, shape_subs)
return mesh
class VaeDecodeStructureTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VaeDecodeStructureTrellis2",
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
],
outputs=[
IO.Mesh.Output("structure_output"),
]
)
@classmethod
def execute(cls, samples, vae):
decoder = vae.struct_dec
decoded = decoder(samples)>0
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
return coords
class Trellis2Conditioning(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -189,8 +211,8 @@ class Trellis2Conditioning(IO.ComfyNode):
# could make 1024 an option
conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
embeds = conditioning["cond_1024"] # should add that
positive = [[conditioning["cond_512"], {embeds}]]
negative = [[conditioning["cond_neg"], {embeds}]]
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
negative = [[conditioning["cond_neg"], {"embeds": embeds}]]
return IO.NodeOutput(positive, negative)
class EmptyShapeLatentTrellis2(IO.ComfyNode):
@ -200,7 +222,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
node_id="EmptyLatentTrellis2",
category="latent/3d",
inputs=[
IO.Latent.Input("structure_output"),
IO.Mesh.Input("structure_output"),
],
outputs=[
IO.Latent.Output(),
@ -210,9 +232,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
@classmethod
def execute(cls, structure_output):
# i will see what i have to do here
coords = structure_output or structure_output.coords
coords = structure_output # or structure_output.coords
in_channels = 32
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
latent = NestedTensor([latent])
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyTextureLatentTrellis2(IO.ComfyNode):
@ -222,7 +245,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
node_id="EmptyLatentTrellis2",
category="latent/3d",
inputs=[
IO.Latent.Input("structure_output"),
IO.Mesh.Input("structure_output"),
],
outputs=[
IO.Latent.Output(),
@ -234,6 +257,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
# TODO
in_channels = 32
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
latent = NestedTensor([latent])
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyStructureLatentTrellis2(IO.ComfyNode):
@ -254,6 +278,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
def execute(cls, res, batch_size):
in_channels = 32
latent = torch.randn(batch_size, in_channels, res, res, res)
latent = NestedTensor([latent])
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
@ -266,7 +291,8 @@ class Trellis2Extension(ComfyExtension):
EmptyStructureLatentTrellis2,
EmptyTextureLatentTrellis2,
VaeDecodeTextureTrellis,
VaeDecodeShapeTrellis
VaeDecodeShapeTrellis,
VaeDecodeStructureTrellis2
]