mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
vae shape decode fixes
This commit is contained in:
parent
a2c8a7aab5
commit
f31c2e1d1d
@ -802,7 +802,7 @@ class Trellis2(nn.Module):
|
|||||||
# may remove the else if texture doesn't require special handling
|
# may remove the else if texture doesn't require special handling
|
||||||
batched_coords = coords
|
batched_coords = coords
|
||||||
feats_flat = x
|
feats_flat = x
|
||||||
x = SparseTensor(feats=feats_flat, coords=batched_coords)
|
x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
||||||
|
|
||||||
if mode == "shape_generation":
|
if mode == "shape_generation":
|
||||||
# TODO
|
# TODO
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
|
||||||
from fractions import Fraction
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from dataclasses import dataclass
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
import comfy.model_management
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from fractions import Fraction
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
||||||
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
||||||
|
|
||||||
|
|
||||||
@ -55,11 +56,12 @@ def sparse_conv3d_forward(self, x):
|
|||||||
Co, Kd, Kh, Kw, Ci = self.weight.shape
|
Co, Kd, Kh, Kw, Ci = self.weight.shape
|
||||||
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
|
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
|
||||||
neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
|
neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
|
||||||
|
x = x.to(self.weight.dtype).to(self.weight.device)
|
||||||
|
|
||||||
out, neighbor_cache_ = sparse_submanifold_conv3d(
|
out, neighbor_cache_ = sparse_submanifold_conv3d(
|
||||||
x.feats,
|
x.feats,
|
||||||
x.coords,
|
x.coords,
|
||||||
torch.Size([*x.shape, *x.spatial_shape]),
|
x.spatial_shape,
|
||||||
self.weight,
|
self.weight,
|
||||||
self.bias,
|
self.bias,
|
||||||
neighbor_cache,
|
neighbor_cache,
|
||||||
@ -100,7 +102,8 @@ class SparseConvNeXtBlock3d(nn.Module):
|
|||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
h = self.conv(x)
|
h = self.conv(x)
|
||||||
h = h.replace(self.norm(h.feats))
|
norm = self.norm.to(torch.float32)
|
||||||
|
h = h.replace(norm(h.feats))
|
||||||
h = h.replace(self.mlp(h.feats))
|
h = h.replace(self.mlp(h.feats))
|
||||||
return h + x
|
return h + x
|
||||||
|
|
||||||
@ -208,13 +211,15 @@ class SparseResBlockC2S3d(nn.Module):
|
|||||||
def forward(self, x, subdiv = None):
|
def forward(self, x, subdiv = None):
|
||||||
if self.pred_subdiv:
|
if self.pred_subdiv:
|
||||||
subdiv = self.to_subdiv(x)
|
subdiv = self.to_subdiv(x)
|
||||||
h = x.replace(self.norm1(x.feats))
|
norm1 = self.norm1.to(torch.float32)
|
||||||
|
norm2 = self.norm2.to(torch.float32)
|
||||||
|
h = x.replace(norm1(x.feats))
|
||||||
h = h.replace(F.silu(h.feats))
|
h = h.replace(F.silu(h.feats))
|
||||||
h = self.conv1(h)
|
h = self.conv1(h)
|
||||||
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
|
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
|
||||||
h = self.updown(h, subdiv_binarized)
|
h = self.updown(h, subdiv_binarized)
|
||||||
x = self.updown(x, subdiv_binarized)
|
x = self.updown(x, subdiv_binarized)
|
||||||
h = h.replace(self.norm2(h.feats))
|
h = h.replace(norm2(h.feats))
|
||||||
h = h.replace(F.silu(h.feats))
|
h = h.replace(F.silu(h.feats))
|
||||||
h = self.conv2(h)
|
h = self.conv2(h)
|
||||||
h = h + self.skip_connection(x)
|
h = h + self.skip_connection(x)
|
||||||
@ -1139,6 +1144,9 @@ class SparseUnetVaeDecoder(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor:
|
def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor:
|
||||||
|
|
||||||
|
dtype = next(self.from_latent.parameters()).dtype
|
||||||
|
device = next(self.from_latent.parameters()).device
|
||||||
|
x.feats = x.feats.to(dtype).to(device)
|
||||||
h = self.from_latent(x)
|
h = self.from_latent(x)
|
||||||
h = h.type(self.dtype)
|
h = h.type(self.dtype)
|
||||||
subs = []
|
subs = []
|
||||||
@ -1152,7 +1160,7 @@ class SparseUnetVaeDecoder(nn.Module):
|
|||||||
h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
|
h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
|
||||||
else:
|
else:
|
||||||
h = block(h)
|
h = block(h)
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.feats.dtype)
|
||||||
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
||||||
h = self.output_layer(h)
|
h = self.output_layer(h)
|
||||||
if return_subs:
|
if return_subs:
|
||||||
@ -1520,6 +1528,8 @@ class Vae(nn.Module):
|
|||||||
|
|
||||||
def decode_shape_slat(self, slat, resolution: int):
|
def decode_shape_slat(self, slat, resolution: int):
|
||||||
self.shape_dec.set_resolution(resolution)
|
self.shape_dec.set_resolution(resolution)
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
self.shape_dec = self.shape_dec.to(device)
|
||||||
return self.shape_dec(slat, return_subs=True)
|
return self.shape_dec(slat, return_subs=True)
|
||||||
|
|
||||||
def decode_tex_slat(self, slat, subs):
|
def decode_tex_slat(self, slat, subs):
|
||||||
|
|||||||
@ -513,8 +513,8 @@ class VAE:
|
|||||||
init_txt_model = True
|
init_txt_model = True
|
||||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
# TODO
|
# TODO
|
||||||
self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||||
self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
|
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
|
||||||
elif "decoder.conv_in.weight" in sd:
|
elif "decoder.conv_in.weight" in sd:
|
||||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
|
from comfy.ldm.trellis2.vae import SparseTensor
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
import torch.nn.functional as TF
|
import torch.nn.functional as TF
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.utils import ProgressBar
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -135,6 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Latent.Input("samples"),
|
IO.Latent.Input("samples"),
|
||||||
|
IO.Voxel.Input("structure_output"),
|
||||||
IO.Vae.Input("vae"),
|
IO.Vae.Input("vae"),
|
||||||
IO.Int.Input("resolution", tooltip="Shape Generation Resolution"),
|
IO.Int.Input("resolution", tooltip="Shape Generation Resolution"),
|
||||||
],
|
],
|
||||||
@ -145,11 +147,14 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae, resolution):
|
def execute(cls, samples, structure_output, vae, resolution):
|
||||||
vae = vae.first_stage_model
|
vae = vae.first_stage_model
|
||||||
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
std = shape_slat_normalization["std"].to(samples)
|
std = shape_slat_normalization["std"].to(samples)
|
||||||
mean = shape_slat_normalization["mean"].to(samples)
|
mean = shape_slat_normalization["mean"].to(samples)
|
||||||
|
samples = SparseTensor(feats = samples, coords=coords)
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
||||||
@ -163,6 +168,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Latent.Input("samples"),
|
IO.Latent.Input("samples"),
|
||||||
|
IO.Voxel.Input("structure_output"),
|
||||||
IO.Vae.Input("vae"),
|
IO.Vae.Input("vae"),
|
||||||
IO.AnyType.Input("shape_subs"),
|
IO.AnyType.Input("shape_subs"),
|
||||||
],
|
],
|
||||||
@ -172,11 +178,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae, shape_subs):
|
def execute(cls, samples, structure_output, vae, shape_subs):
|
||||||
vae = vae.first_stage_model
|
vae = vae.first_stage_model
|
||||||
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
std = tex_slat_normalization["std"].to(samples)
|
std = tex_slat_normalization["std"].to(samples)
|
||||||
mean = tex_slat_normalization["mean"].to(samples)
|
mean = tex_slat_normalization["mean"].to(samples)
|
||||||
|
samples = SparseTensor(feats = samples, coords=coords)
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
mesh = vae.decode_tex_slat(samples, shape_subs)
|
mesh = vae.decode_tex_slat(samples, shape_subs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user