diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index add3f21ce..1dbbc4955 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -7,6 +7,7 @@ from comfy.ldm.trellis2.attention import ( sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder +from comfy.nested_tensor import NestedTensor class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -772,6 +773,11 @@ class SparseStructureFlowModel(nn.Module): return h +def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): + t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) + t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) + return t_new + class Trellis2(nn.Module): def __init__(self, resolution, in_channels = 32, @@ -798,18 +804,25 @@ class Trellis2(nn.Module): args.pop("in_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) - def forward(self, x, timestep, context, **kwargs): - # TODO add mode - mode = kwargs.get("mode", "shape_generation") - if mode != 0: - mode = "texture_generation" if mode == 2 else "shape_generation" - else: + def forward(self, x: NestedTensor, timestep, context, **kwargs): + x = x.tensors[0] + embeds = kwargs.get("embeds") + if not hasattr(x, "feats"): mode = "structure_generation" + else: + if x.feats.shape[1] == 32: + mode = "shape_generation" + else: + mode = "texture_generation" if mode == "shape_generation": - out = self.img2shape(x, timestep, context) + # TODO + out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) elif mode == "texture_generation": out = self.shape2txt(x, timestep, context) - else: + else: # structure + timestep = timestep_reshift(timestep) out = self.structure_model(x, timestep, context) + out = NestedTensor([out]) + out.generation_mode = mode return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 5dabf5246..584fa91ae 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -9,6 +9,17 @@ import numpy as np from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + class SparseConv3d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): super(SparseConv3d, self).__init__() @@ -1337,6 +1348,135 @@ def flexible_dual_grid_to_mesh( return mesh_vertices, mesh_triangles +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + return ChannelLayerNorm32(*args, **kwargs) + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class SparseStructureDecoder(nn.Module): + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h + class Vae(nn.Module): def __init__(self, config, operations=None): super().__init__() @@ -1363,6 +1503,14 @@ class Vae(nn.Module): block_args=[{}, {}, {}, {}, {}], ) + self.struct_dec = SparseStructureDecoder( + out_channels=1, + latent_channels=8, + num_res_blocks=2, + num_res_blocks_middle=2, + channels=[512, 128, 32], + ) + def decode_shape_slat(self, slat, resolution: int): self.shape_dec.set_resolution(resolution) return self.shape_dec(slat, return_subs=True) diff --git a/comfy/model_base.py b/comfy/model_base.py index a5fc81c4d..6bf4fadc9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1461,7 +1461,10 @@ class Trellis2(BaseModel): super().__init__(model_config, model_type, device, unet_model) def extra_conds(self, **kwargs): - return super().extra_conds(**kwargs) + out = super().extra_conds(**kwargs) + embeds = kwargs.get("embeds") + out["embeds"] = comfy.conds.CONDRegular(embeds) + return out class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 70bbbb29d..4a36e2fee 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -6,6 +6,7 @@ import comfy.model_management from PIL import Image import PIL import numpy as np +from comfy.nested_tensor import NestedTensor shape_slat_normalization = { "mean": torch.tensor([ @@ -131,7 +132,8 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae, resolution): + def execute(cls, samples: NestedTensor, vae, resolution): + samples = samples.tensors[0] std = shape_slat_normalization["std"] mean = shape_slat_normalization["mean"] samples = samples * std + mean @@ -157,9 +159,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, vae, shape_subs): - if shape_subs is None: - raise ValueError("Shape subs must be provided for texture generation") - + samples = samples.tensors[0] std = tex_slat_normalization["std"] mean = tex_slat_normalization["mean"] samples = samples * std + mean @@ -167,6 +167,28 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): mesh = vae.decode_tex_slat(samples, shape_subs) return mesh +class VaeDecodeStructureTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VaeDecodeStructureTrellis2", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + ], + outputs=[ + IO.Mesh.Output("structure_output"), + ] + ) + + @classmethod + def execute(cls, samples, vae): + decoder = vae.struct_dec + decoded = decoder(samples)>0 + coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() + return coords + class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -189,8 +211,8 @@ class Trellis2Conditioning(IO.ComfyNode): # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that - positive = [[conditioning["cond_512"], {embeds}]] - negative = [[conditioning["cond_neg"], {embeds}]] + positive = [[conditioning["cond_512"], {"embeds": embeds}]] + negative = [[conditioning["cond_neg"], {"embeds": embeds}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -200,7 +222,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): node_id="EmptyLatentTrellis2", category="latent/3d", inputs=[ - IO.Latent.Input("structure_output"), + IO.Mesh.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -210,9 +232,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): # i will see what i have to do here - coords = structure_output or structure_output.coords + coords = structure_output # or structure_output.coords in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) + latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -222,7 +245,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): node_id="EmptyLatentTrellis2", category="latent/3d", inputs=[ - IO.Latent.Input("structure_output"), + IO.Mesh.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -234,6 +257,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) + latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -254,6 +278,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): def execute(cls, res, batch_size): in_channels = 32 latent = torch.randn(batch_size, in_channels, res, res, res) + latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) @@ -266,7 +291,8 @@ class Trellis2Extension(ComfyExtension): EmptyStructureLatentTrellis2, EmptyTextureLatentTrellis2, VaeDecodeTextureTrellis, - VaeDecodeShapeTrellis + VaeDecodeShapeTrellis, + VaeDecodeStructureTrellis2 ]