mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-03 13:52:31 +08:00
needed updates
This commit is contained in:
parent
6624939505
commit
3002708fe3
@ -7,6 +7,7 @@ from comfy.ldm.trellis2.attention import (
|
|||||||
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
||||||
)
|
)
|
||||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||||
|
from comfy.nested_tensor import NestedTensor
|
||||||
|
|
||||||
class SparseGELU(nn.GELU):
|
class SparseGELU(nn.GELU):
|
||||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||||
@ -772,6 +773,11 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
|
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
||||||
|
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
|
||||||
|
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
|
||||||
|
return t_new
|
||||||
|
|
||||||
class Trellis2(nn.Module):
|
class Trellis2(nn.Module):
|
||||||
def __init__(self, resolution,
|
def __init__(self, resolution,
|
||||||
in_channels = 32,
|
in_channels = 32,
|
||||||
@ -798,18 +804,25 @@ class Trellis2(nn.Module):
|
|||||||
args.pop("in_channels")
|
args.pop("in_channels")
|
||||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x: NestedTensor, timestep, context, **kwargs):
|
||||||
# TODO add mode
|
x = x.tensors[0]
|
||||||
mode = kwargs.get("mode", "shape_generation")
|
embeds = kwargs.get("embeds")
|
||||||
if mode != 0:
|
if not hasattr(x, "feats"):
|
||||||
mode = "texture_generation" if mode == 2 else "shape_generation"
|
|
||||||
else:
|
|
||||||
mode = "structure_generation"
|
mode = "structure_generation"
|
||||||
|
else:
|
||||||
|
if x.feats.shape[1] == 32:
|
||||||
|
mode = "shape_generation"
|
||||||
|
else:
|
||||||
|
mode = "texture_generation"
|
||||||
if mode == "shape_generation":
|
if mode == "shape_generation":
|
||||||
out = self.img2shape(x, timestep, context)
|
# TODO
|
||||||
|
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
||||||
elif mode == "texture_generation":
|
elif mode == "texture_generation":
|
||||||
out = self.shape2txt(x, timestep, context)
|
out = self.shape2txt(x, timestep, context)
|
||||||
else:
|
else: # structure
|
||||||
|
timestep = timestep_reshift(timestep)
|
||||||
out = self.structure_model(x, timestep, context)
|
out = self.structure_model(x, timestep, context)
|
||||||
|
|
||||||
|
out = NestedTensor([out])
|
||||||
|
out.generation_mode = mode
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -9,6 +9,17 @@ import numpy as np
|
|||||||
from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
||||||
|
|
||||||
|
|
||||||
|
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
3D pixel shuffle.
|
||||||
|
"""
|
||||||
|
B, C, H, W, D = x.shape
|
||||||
|
C_ = C // scale_factor**3
|
||||||
|
x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
|
||||||
|
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
|
||||||
|
x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
|
||||||
|
return x
|
||||||
|
|
||||||
class SparseConv3d(nn.Module):
|
class SparseConv3d(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
|
||||||
super(SparseConv3d, self).__init__()
|
super(SparseConv3d, self).__init__()
|
||||||
@ -1337,6 +1348,135 @@ def flexible_dual_grid_to_mesh(
|
|||||||
|
|
||||||
return mesh_vertices, mesh_triangles
|
return mesh_vertices, mesh_triangles
|
||||||
|
|
||||||
|
class ChannelLayerNorm32(LayerNorm32):
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
DIM = x.dim()
|
||||||
|
x = x.permute(0, *range(2, DIM), 1).contiguous()
|
||||||
|
x = super().forward(x)
|
||||||
|
x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
|
||||||
|
return x
|
||||||
|
|
||||||
|
class UpsampleBlock3d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
mode = "conv",
|
||||||
|
):
|
||||||
|
assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
if mode == "conv":
|
||||||
|
self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
|
||||||
|
elif mode == "nearest":
|
||||||
|
assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if hasattr(self, "conv"):
|
||||||
|
x = self.conv(x)
|
||||||
|
return pixel_shuffle_3d(x, 2)
|
||||||
|
else:
|
||||||
|
return F.interpolate(x, scale_factor=2, mode="nearest")
|
||||||
|
|
||||||
|
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
|
||||||
|
return ChannelLayerNorm32(*args, **kwargs)
|
||||||
|
|
||||||
|
class ResBlock3d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
norm_type = "layer",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.out_channels = out_channels or channels
|
||||||
|
|
||||||
|
self.norm1 = norm_layer(norm_type, channels)
|
||||||
|
self.norm2 = norm_layer(norm_type, self.out_channels)
|
||||||
|
self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
|
||||||
|
self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
|
||||||
|
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
h = self.norm1(x)
|
||||||
|
h = F.silu(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
h = self.norm2(h)
|
||||||
|
h = F.silu(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
h = h + self.skip_connection(x)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class SparseStructureDecoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
out_channels: int,
|
||||||
|
latent_channels: int,
|
||||||
|
num_res_blocks: int,
|
||||||
|
channels: List[int],
|
||||||
|
num_res_blocks_middle: int = 2,
|
||||||
|
norm_type = "layer",
|
||||||
|
use_fp16: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.latent_channels = latent_channels
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.channels = channels
|
||||||
|
self.num_res_blocks_middle = num_res_blocks_middle
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.use_fp16 = use_fp16
|
||||||
|
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||||
|
|
||||||
|
self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
|
||||||
|
|
||||||
|
self.middle_block = nn.Sequential(*[
|
||||||
|
ResBlock3d(channels[0], channels[0])
|
||||||
|
for _ in range(num_res_blocks_middle)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList([])
|
||||||
|
for i, ch in enumerate(channels):
|
||||||
|
self.blocks.extend([
|
||||||
|
ResBlock3d(ch, ch)
|
||||||
|
for _ in range(num_res_blocks)
|
||||||
|
])
|
||||||
|
if i < len(channels) - 1:
|
||||||
|
self.blocks.append(
|
||||||
|
UpsampleBlock3d(ch, channels[i+1])
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_layer = nn.Sequential(
|
||||||
|
norm_layer(norm_type, channels[-1]),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Conv3d(channels[-1], out_channels, 3, padding=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_fp16:
|
||||||
|
self.convert_to_fp16()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
h = self.input_layer(x)
|
||||||
|
|
||||||
|
h = h.type(self.dtype)
|
||||||
|
|
||||||
|
h = self.middle_block(h)
|
||||||
|
for block in self.blocks:
|
||||||
|
h = block(h)
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
h = self.out_layer(h)
|
||||||
|
return h
|
||||||
|
|
||||||
class Vae(nn.Module):
|
class Vae(nn.Module):
|
||||||
def __init__(self, config, operations=None):
|
def __init__(self, config, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -1363,6 +1503,14 @@ class Vae(nn.Module):
|
|||||||
block_args=[{}, {}, {}, {}, {}],
|
block_args=[{}, {}, {}, {}, {}],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.struct_dec = SparseStructureDecoder(
|
||||||
|
out_channels=1,
|
||||||
|
latent_channels=8,
|
||||||
|
num_res_blocks=2,
|
||||||
|
num_res_blocks_middle=2,
|
||||||
|
channels=[512, 128, 32],
|
||||||
|
)
|
||||||
|
|
||||||
def decode_shape_slat(self, slat, resolution: int):
|
def decode_shape_slat(self, slat, resolution: int):
|
||||||
self.shape_dec.set_resolution(resolution)
|
self.shape_dec.set_resolution(resolution)
|
||||||
return self.shape_dec(slat, return_subs=True)
|
return self.shape_dec(slat, return_subs=True)
|
||||||
|
|||||||
@ -1461,7 +1461,10 @@ class Trellis2(BaseModel):
|
|||||||
super().__init__(model_config, model_type, device, unet_model)
|
super().__init__(model_config, model_type, device, unet_model)
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
return super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
|
embeds = kwargs.get("embeds")
|
||||||
|
out["embeds"] = comfy.conds.CONDRegular(embeds)
|
||||||
|
return out
|
||||||
|
|
||||||
class Hunyuan3Dv2(BaseModel):
|
class Hunyuan3Dv2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import comfy.model_management
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import PIL
|
import PIL
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from comfy.nested_tensor import NestedTensor
|
||||||
|
|
||||||
shape_slat_normalization = {
|
shape_slat_normalization = {
|
||||||
"mean": torch.tensor([
|
"mean": torch.tensor([
|
||||||
@ -131,7 +132,8 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae, resolution):
|
def execute(cls, samples: NestedTensor, vae, resolution):
|
||||||
|
samples = samples.tensors[0]
|
||||||
std = shape_slat_normalization["std"]
|
std = shape_slat_normalization["std"]
|
||||||
mean = shape_slat_normalization["mean"]
|
mean = shape_slat_normalization["mean"]
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
@ -157,9 +159,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae, shape_subs):
|
def execute(cls, samples, vae, shape_subs):
|
||||||
if shape_subs is None:
|
samples = samples.tensors[0]
|
||||||
raise ValueError("Shape subs must be provided for texture generation")
|
|
||||||
|
|
||||||
std = tex_slat_normalization["std"]
|
std = tex_slat_normalization["std"]
|
||||||
mean = tex_slat_normalization["mean"]
|
mean = tex_slat_normalization["mean"]
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
@ -167,6 +167,28 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
mesh = vae.decode_tex_slat(samples, shape_subs)
|
mesh = vae.decode_tex_slat(samples, shape_subs)
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="VaeDecodeStructureTrellis2",
|
||||||
|
category="latent/3d",
|
||||||
|
inputs=[
|
||||||
|
IO.Latent.Input("samples"),
|
||||||
|
IO.Vae.Input("vae"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Mesh.Output("structure_output"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, samples, vae):
|
||||||
|
decoder = vae.struct_dec
|
||||||
|
decoded = decoder(samples)>0
|
||||||
|
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
|
||||||
|
return coords
|
||||||
|
|
||||||
class Trellis2Conditioning(IO.ComfyNode):
|
class Trellis2Conditioning(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -189,8 +211,8 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
# could make 1024 an option
|
# could make 1024 an option
|
||||||
conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
|
conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
|
||||||
embeds = conditioning["cond_1024"] # should add that
|
embeds = conditioning["cond_1024"] # should add that
|
||||||
positive = [[conditioning["cond_512"], {embeds}]]
|
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||||
negative = [[conditioning["cond_neg"], {embeds}]]
|
negative = [[conditioning["cond_neg"], {"embeds": embeds}]]
|
||||||
return IO.NodeOutput(positive, negative)
|
return IO.NodeOutput(positive, negative)
|
||||||
|
|
||||||
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||||
@ -200,7 +222,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
node_id="EmptyLatentTrellis2",
|
node_id="EmptyLatentTrellis2",
|
||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Latent.Input("structure_output"),
|
IO.Mesh.Input("structure_output"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
@ -210,9 +232,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output):
|
def execute(cls, structure_output):
|
||||||
# i will see what i have to do here
|
# i will see what i have to do here
|
||||||
coords = structure_output or structure_output.coords
|
coords = structure_output # or structure_output.coords
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||||
|
latent = NestedTensor([latent])
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
@ -222,7 +245,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
node_id="EmptyLatentTrellis2",
|
node_id="EmptyLatentTrellis2",
|
||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Latent.Input("structure_output"),
|
IO.Mesh.Input("structure_output"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
@ -234,6 +257,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
# TODO
|
# TODO
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
|
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||||
|
latent = NestedTensor([latent])
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
@ -254,6 +278,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
def execute(cls, res, batch_size):
|
def execute(cls, res, batch_size):
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = torch.randn(batch_size, in_channels, res, res, res)
|
latent = torch.randn(batch_size, in_channels, res, res, res)
|
||||||
|
latent = NestedTensor([latent])
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
|
|
||||||
@ -266,7 +291,8 @@ class Trellis2Extension(ComfyExtension):
|
|||||||
EmptyStructureLatentTrellis2,
|
EmptyStructureLatentTrellis2,
|
||||||
EmptyTextureLatentTrellis2,
|
EmptyTextureLatentTrellis2,
|
||||||
VaeDecodeTextureTrellis,
|
VaeDecodeTextureTrellis,
|
||||||
VaeDecodeShapeTrellis
|
VaeDecodeShapeTrellis,
|
||||||
|
VaeDecodeStructureTrellis2
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user