mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
vae shape decode fixes
This commit is contained in:
parent
a2c8a7aab5
commit
f31c2e1d1d
@ -802,7 +802,7 @@ class Trellis2(nn.Module):
|
||||
# may remove the else if texture doesn't require special handling
|
||||
batched_coords = coords
|
||||
feats_flat = x
|
||||
x = SparseTensor(feats=feats_flat, coords=batched_coords)
|
||||
x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
||||
|
||||
if mode == "shape_generation":
|
||||
# TODO
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
||||
from fractions import Fraction
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import comfy.model_management
|
||||
import torch.nn.functional as F
|
||||
from fractions import Fraction
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
||||
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
||||
|
||||
|
||||
@ -55,11 +56,12 @@ def sparse_conv3d_forward(self, x):
|
||||
Co, Kd, Kh, Kw, Ci = self.weight.shape
|
||||
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
|
||||
neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
|
||||
x = x.to(self.weight.dtype).to(self.weight.device)
|
||||
|
||||
out, neighbor_cache_ = sparse_submanifold_conv3d(
|
||||
x.feats,
|
||||
x.coords,
|
||||
torch.Size([*x.shape, *x.spatial_shape]),
|
||||
x.spatial_shape,
|
||||
self.weight,
|
||||
self.bias,
|
||||
neighbor_cache,
|
||||
@ -100,7 +102,8 @@ class SparseConvNeXtBlock3d(nn.Module):
|
||||
|
||||
def _forward(self, x):
|
||||
h = self.conv(x)
|
||||
h = h.replace(self.norm(h.feats))
|
||||
norm = self.norm.to(torch.float32)
|
||||
h = h.replace(norm(h.feats))
|
||||
h = h.replace(self.mlp(h.feats))
|
||||
return h + x
|
||||
|
||||
@ -208,13 +211,15 @@ class SparseResBlockC2S3d(nn.Module):
|
||||
def forward(self, x, subdiv = None):
|
||||
if self.pred_subdiv:
|
||||
subdiv = self.to_subdiv(x)
|
||||
h = x.replace(self.norm1(x.feats))
|
||||
norm1 = self.norm1.to(torch.float32)
|
||||
norm2 = self.norm2.to(torch.float32)
|
||||
h = x.replace(norm1(x.feats))
|
||||
h = h.replace(F.silu(h.feats))
|
||||
h = self.conv1(h)
|
||||
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
|
||||
h = self.updown(h, subdiv_binarized)
|
||||
x = self.updown(x, subdiv_binarized)
|
||||
h = h.replace(self.norm2(h.feats))
|
||||
h = h.replace(norm2(h.feats))
|
||||
h = h.replace(F.silu(h.feats))
|
||||
h = self.conv2(h)
|
||||
h = h + self.skip_connection(x)
|
||||
@ -1139,6 +1144,9 @@ class SparseUnetVaeDecoder(nn.Module):
|
||||
|
||||
def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor:
|
||||
|
||||
dtype = next(self.from_latent.parameters()).dtype
|
||||
device = next(self.from_latent.parameters()).device
|
||||
x.feats = x.feats.to(dtype).to(device)
|
||||
h = self.from_latent(x)
|
||||
h = h.type(self.dtype)
|
||||
subs = []
|
||||
@ -1152,7 +1160,7 @@ class SparseUnetVaeDecoder(nn.Module):
|
||||
h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
|
||||
else:
|
||||
h = block(h)
|
||||
h = h.type(x.dtype)
|
||||
h = h.type(x.feats.dtype)
|
||||
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
||||
h = self.output_layer(h)
|
||||
if return_subs:
|
||||
@ -1520,6 +1528,8 @@ class Vae(nn.Module):
|
||||
|
||||
def decode_shape_slat(self, slat, resolution: int):
|
||||
self.shape_dec.set_resolution(resolution)
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.shape_dec = self.shape_dec.to(device)
|
||||
return self.shape_dec(slat, return_subs=True)
|
||||
|
||||
def decode_tex_slat(self, slat, subs):
|
||||
|
||||
@ -513,8 +513,8 @@ class VAE:
|
||||
init_txt_model = True
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# TODO
|
||||
self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
from comfy.utils import ProgressBar
|
||||
import torch.nn.functional as TF
|
||||
import comfy.model_management
|
||||
from comfy.utils import ProgressBar
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -135,6 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Voxel.Input("structure_output"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Int.Input("resolution", tooltip="Shape Generation Resolution"),
|
||||
],
|
||||
@ -145,11 +147,14 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, resolution):
|
||||
def execute(cls, samples, structure_output, vae, resolution):
|
||||
vae = vae.first_stage_model
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
samples = samples["samples"]
|
||||
std = shape_slat_normalization["std"].to(samples)
|
||||
mean = shape_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords)
|
||||
samples = samples * std + mean
|
||||
|
||||
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
||||
@ -163,6 +168,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Voxel.Input("structure_output"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.AnyType.Input("shape_subs"),
|
||||
],
|
||||
@ -172,11 +178,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, shape_subs):
|
||||
def execute(cls, samples, structure_output, vae, shape_subs):
|
||||
vae = vae.first_stage_model
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
samples = samples["samples"]
|
||||
std = tex_slat_normalization["std"].to(samples)
|
||||
mean = tex_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords)
|
||||
samples = samples * std + mean
|
||||
|
||||
mesh = vae.decode_tex_slat(samples, shape_subs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user