mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
needed updates
This commit is contained in:
parent
6624939505
commit
3002708fe3
@ -7,6 +7,7 @@ from comfy.ldm.trellis2.attention import (
|
||||
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
||||
)
|
||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||
from comfy.nested_tensor import NestedTensor
|
||||
|
||||
class SparseGELU(nn.GELU):
|
||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||
@ -772,6 +773,11 @@ class SparseStructureFlowModel(nn.Module):
|
||||
|
||||
return h
|
||||
|
||||
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
||||
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
|
||||
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
|
||||
return t_new
|
||||
|
||||
class Trellis2(nn.Module):
|
||||
def __init__(self, resolution,
|
||||
in_channels = 32,
|
||||
@ -798,18 +804,25 @@ class Trellis2(nn.Module):
|
||||
args.pop("in_channels")
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
# TODO add mode
|
||||
mode = kwargs.get("mode", "shape_generation")
|
||||
if mode != 0:
|
||||
mode = "texture_generation" if mode == 2 else "shape_generation"
|
||||
else:
|
||||
def forward(self, x: NestedTensor, timestep, context, **kwargs):
|
||||
x = x.tensors[0]
|
||||
embeds = kwargs.get("embeds")
|
||||
if not hasattr(x, "feats"):
|
||||
mode = "structure_generation"
|
||||
else:
|
||||
if x.feats.shape[1] == 32:
|
||||
mode = "shape_generation"
|
||||
else:
|
||||
mode = "texture_generation"
|
||||
if mode == "shape_generation":
|
||||
out = self.img2shape(x, timestep, context)
|
||||
# TODO
|
||||
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
||||
elif mode == "texture_generation":
|
||||
out = self.shape2txt(x, timestep, context)
|
||||
else:
|
||||
else: # structure
|
||||
timestep = timestep_reshift(timestep)
|
||||
out = self.structure_model(x, timestep, context)
|
||||
|
||||
out = NestedTensor([out])
|
||||
out.generation_mode = mode
|
||||
return out
|
||||
|
||||
@ -9,6 +9,17 @@ import numpy as np
|
||||
from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
||||
|
||||
|
||||
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
||||
"""
|
||||
3D pixel shuffle.
|
||||
"""
|
||||
B, C, H, W, D = x.shape
|
||||
C_ = C // scale_factor**3
|
||||
x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
|
||||
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
|
||||
return x
|
||||
|
||||
class SparseConv3d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
|
||||
super(SparseConv3d, self).__init__()
|
||||
@ -1337,6 +1348,135 @@ def flexible_dual_grid_to_mesh(
|
||||
|
||||
return mesh_vertices, mesh_triangles
|
||||
|
||||
class ChannelLayerNorm32(LayerNorm32):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
DIM = x.dim()
|
||||
x = x.permute(0, *range(2, DIM), 1).contiguous()
|
||||
x = super().forward(x)
|
||||
x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
|
||||
return x
|
||||
|
||||
class UpsampleBlock3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
mode = "conv",
|
||||
):
|
||||
assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
|
||||
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if mode == "conv":
|
||||
self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
|
||||
elif mode == "nearest":
|
||||
assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "conv"):
|
||||
x = self.conv(x)
|
||||
return pixel_shuffle_3d(x, 2)
|
||||
else:
|
||||
return F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
|
||||
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
|
||||
return ChannelLayerNorm32(*args, **kwargs)
|
||||
|
||||
class ResBlock3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
norm_type = "layer",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
||||
self.norm1 = norm_layer(norm_type, channels)
|
||||
self.norm2 = norm_layer(norm_type, self.out_channels)
|
||||
self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
|
||||
self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
|
||||
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.norm1(x)
|
||||
h = F.silu(h)
|
||||
h = self.conv1(h)
|
||||
h = self.norm2(h)
|
||||
h = F.silu(h)
|
||||
h = self.conv2(h)
|
||||
h = h + self.skip_connection(x)
|
||||
return h
|
||||
|
||||
|
||||
class SparseStructureDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
out_channels: int,
|
||||
latent_channels: int,
|
||||
num_res_blocks: int,
|
||||
channels: List[int],
|
||||
num_res_blocks_middle: int = 2,
|
||||
norm_type = "layer",
|
||||
use_fp16: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
self.latent_channels = latent_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.channels = channels
|
||||
self.num_res_blocks_middle = num_res_blocks_middle
|
||||
self.norm_type = norm_type
|
||||
self.use_fp16 = use_fp16
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
|
||||
self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
|
||||
|
||||
self.middle_block = nn.Sequential(*[
|
||||
ResBlock3d(channels[0], channels[0])
|
||||
for _ in range(num_res_blocks_middle)
|
||||
])
|
||||
|
||||
self.blocks = nn.ModuleList([])
|
||||
for i, ch in enumerate(channels):
|
||||
self.blocks.extend([
|
||||
ResBlock3d(ch, ch)
|
||||
for _ in range(num_res_blocks)
|
||||
])
|
||||
if i < len(channels) - 1:
|
||||
self.blocks.append(
|
||||
UpsampleBlock3d(ch, channels[i+1])
|
||||
)
|
||||
|
||||
self.out_layer = nn.Sequential(
|
||||
norm_layer(norm_type, channels[-1]),
|
||||
nn.SiLU(),
|
||||
nn.Conv3d(channels[-1], out_channels, 3, padding=1)
|
||||
)
|
||||
|
||||
if use_fp16:
|
||||
self.convert_to_fp16()
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.input_layer(x)
|
||||
|
||||
h = h.type(self.dtype)
|
||||
|
||||
h = self.middle_block(h)
|
||||
for block in self.blocks:
|
||||
h = block(h)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = self.out_layer(h)
|
||||
return h
|
||||
|
||||
class Vae(nn.Module):
|
||||
def __init__(self, config, operations=None):
|
||||
super().__init__()
|
||||
@ -1363,6 +1503,14 @@ class Vae(nn.Module):
|
||||
block_args=[{}, {}, {}, {}, {}],
|
||||
)
|
||||
|
||||
self.struct_dec = SparseStructureDecoder(
|
||||
out_channels=1,
|
||||
latent_channels=8,
|
||||
num_res_blocks=2,
|
||||
num_res_blocks_middle=2,
|
||||
channels=[512, 128, 32],
|
||||
)
|
||||
|
||||
def decode_shape_slat(self, slat, resolution: int):
|
||||
self.shape_dec.set_resolution(resolution)
|
||||
return self.shape_dec(slat, return_subs=True)
|
||||
|
||||
@ -1461,7 +1461,10 @@ class Trellis2(BaseModel):
|
||||
super().__init__(model_config, model_type, device, unet_model)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
return super().extra_conds(**kwargs)
|
||||
out = super().extra_conds(**kwargs)
|
||||
embeds = kwargs.get("embeds")
|
||||
out["embeds"] = comfy.conds.CONDRegular(embeds)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
|
||||
@ -6,6 +6,7 @@ import comfy.model_management
|
||||
from PIL import Image
|
||||
import PIL
|
||||
import numpy as np
|
||||
from comfy.nested_tensor import NestedTensor
|
||||
|
||||
shape_slat_normalization = {
|
||||
"mean": torch.tensor([
|
||||
@ -131,7 +132,8 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, resolution):
|
||||
def execute(cls, samples: NestedTensor, vae, resolution):
|
||||
samples = samples.tensors[0]
|
||||
std = shape_slat_normalization["std"]
|
||||
mean = shape_slat_normalization["mean"]
|
||||
samples = samples * std + mean
|
||||
@ -157,9 +159,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, shape_subs):
|
||||
if shape_subs is None:
|
||||
raise ValueError("Shape subs must be provided for texture generation")
|
||||
|
||||
samples = samples.tensors[0]
|
||||
std = tex_slat_normalization["std"]
|
||||
mean = tex_slat_normalization["mean"]
|
||||
samples = samples * std + mean
|
||||
@ -167,6 +167,28 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
mesh = vae.decode_tex_slat(samples, shape_subs)
|
||||
return mesh
|
||||
|
||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VaeDecodeStructureTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("structure_output"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae):
|
||||
decoder = vae.struct_dec
|
||||
decoded = decoder(samples)>0
|
||||
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
|
||||
return coords
|
||||
|
||||
class Trellis2Conditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -189,8 +211,8 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
# could make 1024 an option
|
||||
conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
|
||||
embeds = conditioning["cond_1024"] # should add that
|
||||
positive = [[conditioning["cond_512"], {embeds}]]
|
||||
negative = [[conditioning["cond_neg"], {embeds}]]
|
||||
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||
negative = [[conditioning["cond_neg"], {"embeds": embeds}]]
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
@ -200,7 +222,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
node_id="EmptyLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("structure_output"),
|
||||
IO.Mesh.Input("structure_output"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
@ -210,9 +232,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, structure_output):
|
||||
# i will see what i have to do here
|
||||
coords = structure_output or structure_output.coords
|
||||
coords = structure_output # or structure_output.coords
|
||||
in_channels = 32
|
||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||
latent = NestedTensor([latent])
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
@ -222,7 +245,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
node_id="EmptyLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("structure_output"),
|
||||
IO.Mesh.Input("structure_output"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
@ -234,6 +257,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
# TODO
|
||||
in_channels = 32
|
||||
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||
latent = NestedTensor([latent])
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
@ -254,6 +278,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
def execute(cls, res, batch_size):
|
||||
in_channels = 32
|
||||
latent = torch.randn(batch_size, in_channels, res, res, res)
|
||||
latent = NestedTensor([latent])
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
|
||||
@ -266,7 +291,8 @@ class Trellis2Extension(ComfyExtension):
|
||||
EmptyStructureLatentTrellis2,
|
||||
EmptyTextureLatentTrellis2,
|
||||
VaeDecodeTextureTrellis,
|
||||
VaeDecodeShapeTrellis
|
||||
VaeDecodeShapeTrellis,
|
||||
VaeDecodeStructureTrellis2
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user