mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
removed unnecessary code + optimizations + progress
This commit is contained in:
parent
f31c2e1d1d
commit
39270fdca9
@ -2,7 +2,7 @@
|
||||
|
||||
import math
|
||||
import torch
|
||||
from typing import Dict, Callable
|
||||
from typing import Callable
|
||||
import logging
|
||||
|
||||
NO_TRITON = False
|
||||
@ -201,13 +201,13 @@ class TorchHashMap:
|
||||
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
|
||||
device = keys.device
|
||||
# use long for searchsorted
|
||||
self.sorted_keys, order = torch.sort(keys.long())
|
||||
self.sorted_vals = values.long()[order]
|
||||
self.sorted_keys, order = torch.sort(keys.to(torch.long))
|
||||
self.sorted_vals = values.to(torch.long)[order]
|
||||
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
|
||||
self._n = self.sorted_keys.numel()
|
||||
|
||||
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
|
||||
flat = flat_keys.long()
|
||||
flat = flat_keys.to(torch.long)
|
||||
if self._n == 0:
|
||||
return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
|
||||
idx = torch.searchsorted(self.sorted_keys, flat)
|
||||
@ -225,44 +225,35 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
|
||||
device = neighbor_map.device
|
||||
N, V = neighbor_map.shape
|
||||
|
||||
sentinel = UINT32_SENTINEL
|
||||
|
||||
neigh = neighbor_map.to(torch.long)
|
||||
sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device)
|
||||
|
||||
|
||||
neigh_map_T = neigh.t().reshape(-1)
|
||||
|
||||
neigh_map_T = neighbor_map.t().reshape(-1)
|
||||
neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
|
||||
|
||||
mask = (neigh != sentinel).to(torch.long)
|
||||
mask = (neighbor_map != sentinel).to(torch.long)
|
||||
gray_code = torch.zeros(N, dtype=torch.long, device=device)
|
||||
|
||||
powers = (1 << torch.arange(V, dtype=torch.long, device=device))
|
||||
for v in range(V):
|
||||
gray_code |= (mask[:, v] << v)
|
||||
|
||||
gray_long = (mask * powers).sum(dim=1)
|
||||
|
||||
gray_code = gray_long.to(torch.int32)
|
||||
|
||||
binary_long = gray_long.clone()
|
||||
binary_code = gray_code.clone()
|
||||
for v in range(1, V):
|
||||
binary_long ^= (gray_long >> v)
|
||||
binary_code = binary_long.to(torch.int32)
|
||||
binary_code ^= (gray_code >> v)
|
||||
|
||||
sorted_idx = torch.argsort(binary_code)
|
||||
|
||||
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,)
|
||||
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0)
|
||||
|
||||
total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
|
||||
|
||||
if total_valid_signal > 0:
|
||||
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
|
||||
to = (prefix_sum_neighbor_mask[pos] - 1).long()
|
||||
|
||||
valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
||||
valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
||||
|
||||
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
|
||||
|
||||
to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long)
|
||||
|
||||
valid_signal_i[to] = (pos % N).to(torch.long)
|
||||
|
||||
valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
|
||||
else:
|
||||
valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
|
||||
@ -272,9 +263,7 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
|
||||
seg[0] = 0
|
||||
if V > 0:
|
||||
idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
|
||||
seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long)
|
||||
else:
|
||||
pass
|
||||
seg[1:] = prefix_sum_neighbor_mask[idxs]
|
||||
|
||||
return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
|
||||
|
||||
@ -295,40 +284,41 @@ def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
||||
gray_code: torch.Tensor, # [N], int32-like (non-negative)
|
||||
sorted_idx: torch.Tensor, # [N], long (indexing into gray_code)
|
||||
gray_code: torch.Tensor,
|
||||
sorted_idx: torch.Tensor,
|
||||
block_size: int
|
||||
):
|
||||
device = gray_code.device
|
||||
N = gray_code.numel()
|
||||
|
||||
# num of blocks (same as CUDA)
|
||||
num_blocks = (N + block_size - 1) // block_size
|
||||
|
||||
# Ensure dtypes
|
||||
gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast
|
||||
sorted_idx = sorted_idx.to(torch.long)
|
||||
|
||||
# 1) Group gray_code by blocks and compute OR across each block
|
||||
# pad the last block with zeros if necessary so we can reshape
|
||||
pad = num_blocks * block_size - N
|
||||
if pad > 0:
|
||||
pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device)
|
||||
gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0)
|
||||
pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device)
|
||||
gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0)
|
||||
else:
|
||||
gray_padded = gray_long[sorted_idx]
|
||||
gray_padded = gray_code[sorted_idx]
|
||||
|
||||
# reshape to (num_blocks, block_size) and compute bitwise_or across dim=1
|
||||
gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries
|
||||
# reduce with bitwise_or
|
||||
reduced_code = gray_blocks[:, 0].clone()
|
||||
for i in range(1, block_size):
|
||||
reduced_code |= gray_blocks[:, i]
|
||||
reduced_code = reduced_code.to(torch.int32) # match CUDA int32
|
||||
gray_blocks = gray_padded.view(num_blocks, block_size)
|
||||
|
||||
reduced_code = gray_blocks
|
||||
while reduced_code.shape[1] > 1:
|
||||
half = reduced_code.shape[1] // 2
|
||||
remainder = reduced_code.shape[1] % 2
|
||||
|
||||
left = reduced_code[:, :half * 2:2]
|
||||
right = reduced_code[:, 1:half * 2:2]
|
||||
merged = left | right
|
||||
|
||||
if remainder:
|
||||
reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1)
|
||||
else:
|
||||
reduced_code = merged
|
||||
|
||||
reduced_code = reduced_code.squeeze(1)
|
||||
|
||||
seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32)
|
||||
|
||||
# 2) compute seglen (popcount per reduced_code) and seg (prefix sum)
|
||||
seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks]
|
||||
# seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i
|
||||
seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
|
||||
seg[0] = 0
|
||||
if num_blocks > 0:
|
||||
@ -336,30 +326,20 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
||||
|
||||
total = int(seg[-1].item())
|
||||
|
||||
# 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row
|
||||
if total == 0:
|
||||
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
|
||||
return valid_kernel_idx, seg
|
||||
return torch.empty((0,), dtype=torch.int32, device=device), seg
|
||||
|
||||
max_val = int(reduced_code.max().item())
|
||||
V = max_val.bit_length() if max_val > 0 else 0
|
||||
# If you know V externally, pass it instead or set here explicitly.
|
||||
V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0
|
||||
|
||||
if V == 0:
|
||||
# no bits set anywhere
|
||||
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
|
||||
return valid_kernel_idx, seg
|
||||
return torch.empty((0,), dtype=torch.int32, device=device), seg
|
||||
|
||||
# build mask of shape (num_blocks, V): True where bit is set
|
||||
bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V]
|
||||
# shifted = reduced_code[:, None] >> bit_pos[None, :]
|
||||
shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0)
|
||||
bits = (shifted & 1).to(torch.bool) # (num_blocks, V)
|
||||
bit_pos = torch.arange(0, V, dtype=torch.int32, device=device)
|
||||
shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0)
|
||||
bits = (shifted & 1).to(torch.bool)
|
||||
|
||||
positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
|
||||
|
||||
valid_positions = positions[bits]
|
||||
valid_kernel_idx = valid_positions.to(torch.int32).contiguous()
|
||||
valid_kernel_idx = positions[bits].to(torch.int32).contiguous()
|
||||
|
||||
return valid_kernel_idx, seg
|
||||
|
||||
@ -425,35 +405,6 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache
|
||||
|
||||
return out, neighbor
|
||||
|
||||
class Voxel:
    """Container for a set of voxels and their per-voxel attributes.

    Stores a grid origin, a scalar voxel size, voxel coordinates, an
    attribute tensor and a ``layout`` mapping that names column selections
    of the attribute tensor.
    """

    def __init__(
        self,
        origin: list,
        voxel_size: float,
        coords: torch.Tensor = None,
        attrs: torch.Tensor = None,
        layout = None,
        device: torch.device = None
    ):
        """Store the voxel data.

        Args:
            origin: Sequence converted to a float32 tensor on ``device``.
            voxel_size: Scale applied to coordinates in ``position``.
            coords: Voxel coordinates; stored as given.
            attrs: Per-voxel attributes; stored as given.
            layout: Mapping from attribute name to a column index/slice of
                ``attrs``. ``None`` is replaced by a fresh empty dict.
            device: Device used for the ``origin`` tensor only; ``coords``
                and ``attrs`` are not moved.
        """
        self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        # A new dict per instance, so instances never share a default layout.
        self.layout = {} if layout is None else layout
        self.device = device

    @property
    def position(self):
        """Return ``(coords + 0.5) * voxel_size + origin``.

        The 0.5 offset presumably selects the voxel centre rather than its
        corner -- TODO confirm against callers.
        """
        shifted = self.coords + 0.5
        # origin gains a leading axis so it broadcasts over every coords row.
        return shifted * self.voxel_size + self.origin[None, :]

    def split_attrs(self):
        """Return a dict mapping each layout key to ``attrs[:, layout[key]]``."""
        pieces = {}
        for key in self.layout:
            pieces[key] = self.attrs[:, self.layout[key]]
        return pieces
|
||||
|
||||
class Mesh:
|
||||
def __init__(self,
|
||||
vertices,
|
||||
@ -480,35 +431,3 @@ class Mesh:
|
||||
|
||||
def cpu(self):
|
||||
return self.to('cpu')
|
||||
|
||||
class MeshWithVoxel(Mesh, Voxel):
    """A mesh (vertices and faces) bundled with voxel attribute data.

    Neither base-class constructor is invoked; every attribute is assigned
    directly in ``__init__``.
    """

    def __init__(self,
            vertices: torch.Tensor,
            faces: torch.Tensor,
            origin: list,
            voxel_size: float,
            coords: torch.Tensor,
            attrs: torch.Tensor,
            voxel_shape: torch.Size,
            layout: Dict = None,
        ):
        """Store mesh and voxel data.

        Args:
            vertices: Vertex tensor; converted with ``.float()``.
            faces: Face index tensor; converted with ``.int()``.
            origin: Sequence converted to a float32 tensor.
            voxel_size: Stored as given.
            coords: Voxel coordinates; stored as given.
            attrs: Per-voxel attributes; stored as given.
            voxel_shape: Stored as given.
            layout: Mapping from attribute name to a column selection of
                ``attrs``. Defaults to a fresh empty dict per instance.
        """
        # The default used to be a literal ``{}``, i.e. one dict object
        # shared by every instance constructed without a layout. Use None
        # as the sentinel instead, as the sibling Voxel class does.
        if layout is None:
            layout = {}
        self.vertices = vertices.float()
        self.faces = faces.int()
        # NOTE(review): self.device is never assigned in this constructor;
        # presumably it resolves through an inherited attribute or property
        # that depends on self.vertices -- confirm against Mesh.
        self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        self.voxel_shape = voxel_shape
        self.layout = layout

    def to(self, device, non_blocking=False):
        """Return a new MeshWithVoxel with its tensors moved to ``device``.

        ``origin`` is passed back as a Python list and rebuilt by the
        constructor; ``voxel_size``, ``voxel_shape`` and ``layout`` are
        passed through unchanged (``layout`` is shared, not copied).
        """
        return MeshWithVoxel(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.origin.tolist(),
            self.voxel_size,
            self.coords.to(device, non_blocking=non_blocking),
            self.attrs.to(device, non_blocking=non_blocking),
            self.voxel_shape,
            self.layout,
        )
|
||||
|
||||
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
||||
from fractions import Fraction
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
||||
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
||||
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, sparse_submanifold_conv3d
|
||||
|
||||
|
||||
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
||||
@ -210,6 +210,8 @@ class SparseResBlockC2S3d(nn.Module):
|
||||
|
||||
def forward(self, x, subdiv = None):
|
||||
if self.pred_subdiv:
|
||||
dtype = next(self.to_subdiv.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
subdiv = self.to_subdiv(x)
|
||||
norm1 = self.norm1.to(torch.float32)
|
||||
norm2 = self.norm2.to(torch.float32)
|
||||
@ -987,114 +989,7 @@ def convert_module_to_f16(l):
|
||||
for p in l.parameters():
|
||||
p.data = p.data.half()
|
||||
|
||||
|
||||
|
||||
class SparseUnetVaeEncoder(nn.Module):
    """
    Encoder side of the sparse U-Net VAE.

    Applies an input projection, runs a sequence of stages whose block
    classes are looked up by name in this module's globals, layer-normalises
    the features and projects them to ``2 * latent_channels`` values that
    are split into mean and log-variance.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        down_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
    ):
        """Build the projections and the per-stage block lists.

        Args:
            in_channels: Feature width accepted by the input projection.
            model_channels: Channel width of each stage.
            latent_channels: Latent width; the output projection emits twice
                this many channels.
            num_blocks: Number of blocks in each stage.
            block_type: Per-stage class name of the blocks, resolved through
                ``globals()``.
            down_block_type: Per-stage class name of the block appended after
                every stage except the last; it receives the current and the
                next stage's channel widths.
            block_args: Per-stage keyword arguments forwarded to both kinds
                of block.
            use_fp16: Run the stages in float16 when true, float32 otherwise.
        """
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = SparseLinear(in_channels, model_channels[0])
        self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels)

        registry = globals()
        final_stage = len(num_blocks) - 1
        self.blocks = nn.ModuleList([])
        for stage, depth in enumerate(num_blocks):
            layers = nn.ModuleList([])
            self.blocks.append(layers)
            for _ in range(depth):
                layers.append(
                    registry[block_type[stage]](
                        model_channels[stage],
                        **block_args[stage],
                    )
                )
            # Every stage but the last ends with a block that changes the
            # channel width to that of the following stage.
            if stage < final_stage:
                layers.append(
                    registry[down_block_type[stage]](
                        model_channels[stage],
                        model_channels[stage + 1],
                        **block_args[stage],
                    )
                )

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False):
        """Encode ``x`` into a latent sparse tensor.

        Args:
            x: Input sparse tensor.
            sample_posterior: When true, return ``mean + std * noise`` with
                ``std = exp(0.5 * logvar)``; otherwise return the mean.
            return_raw: When true, also return the mean and log-variance.

        Returns:
            ``z``, or the tuple ``(z, mean, logvar)`` when ``return_raw``.
        """
        h = self.input_layer(x).type(self.dtype)
        for stage in self.blocks:
            for layer in stage:
                h = layer(h)
        # Return to the caller's dtype before normalising and projecting.
        h = h.type(x.dtype)
        normed = F.layer_norm(h.feats, h.feats.shape[-1:])
        h = self.to_latent(h.replace(normed))

        # First half of the channels is the mean, second half the log-variance.
        mean, logvar = h.feats.chunk(2, dim=-1)
        z = mean
        if sample_posterior:
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        z = h.replace(z)

        if return_raw:
            return z, mean, logvar
        return z
|
||||
|
||||
|
||||
|
||||
class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder):
    """SparseUnetVaeEncoder fixed to 6 input channels.

    The input features are the concatenation of ``vertices.feats`` and
    ``intersected.feats``, each shifted by -0.5.
    """

    def __init__(
        self,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        down_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
    ):
        """Forward every argument to the base encoder with ``in_channels=6``.

        NOTE(review): 6 presumably equals the combined feature width of the
        two inputs to ``forward`` -- confirm against callers.
        """
        super().__init__(
            in_channels=6,
            model_channels=model_channels,
            latent_channels=latent_channels,
            num_blocks=num_blocks,
            block_type=block_type,
            down_block_type=down_block_type,
            block_args=block_args,
            use_fp16=use_fp16,
        )

    def forward(self, vertices: SparseTensor, intersected: SparseTensor, sample_posterior=False, return_raw=False):
        """Encode the two feature sets.

        Both feature tensors are shifted by -0.5 (``intersected`` is cast to
        float first), concatenated along dim 1, and placed on ``vertices``
        via ``replace`` before delegating to the base ``forward``.
        """
        vertex_part = vertices.feats - 0.5
        intersect_part = intersected.feats.float() - 0.5
        merged = torch.cat([vertex_part, intersect_part], dim=1)
        return super().forward(vertices.replace(merged), sample_posterior, return_raw)
|
||||
|
||||
class SparseUnetVaeDecoder(nn.Module):
|
||||
"""
|
||||
Sparse Swin Transformer Unet VAE model.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
out_channels: int,
|
||||
@ -1218,10 +1113,10 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
|
||||
N = coords.shape[0]
|
||||
# compute flat keys for all coords (prepend batch 0 same as original code)
|
||||
b = torch.zeros((N,), dtype=torch.long, device=device)
|
||||
x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long()
|
||||
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
|
||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
|
||||
values = torch.arange(N, dtype=torch.long, device=device)
|
||||
values = torch.arange(N, dtype=torch.int32, device=device)
|
||||
DEFAULT_VAL = 0xffffffff # sentinel used in original code
|
||||
return TorchHashMap(flat_keys, values, DEFAULT_VAL)
|
||||
|
||||
@ -1295,13 +1190,12 @@ def flexible_dual_grid_to_mesh(
|
||||
|
||||
# Extract mesh
|
||||
N = dual_vertices.shape[0]
|
||||
mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5
|
||||
|
||||
if hashmap_builder is None:
|
||||
# build local TorchHashMap
|
||||
device = coords.device
|
||||
b = torch.zeros((N,), dtype=torch.long, device=device)
|
||||
x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long()
|
||||
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
|
||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
|
||||
values = torch.arange(N, dtype=torch.long, device=device)
|
||||
@ -1316,9 +1210,9 @@ def flexible_dual_grid_to_mesh(
|
||||
M = connected_voxel.shape[0]
|
||||
# flatten connected voxel coords and lookup
|
||||
conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device)
|
||||
conn_x = connected_voxel.reshape(-1, 3)[:, 0].long()
|
||||
conn_y = connected_voxel.reshape(-1, 3)[:, 1].long()
|
||||
conn_z = connected_voxel.reshape(-1, 3)[:, 2].long()
|
||||
conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32)
|
||||
conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32)
|
||||
conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32)
|
||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z
|
||||
|
||||
@ -1526,17 +1420,18 @@ class Vae(nn.Module):
|
||||
channels=[512, 128, 32],
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_shape_slat(self, slat, resolution: int):
|
||||
self.shape_dec.set_resolution(resolution)
|
||||
device = comfy.model_management.get_torch_device()
|
||||
self.shape_dec = self.shape_dec.to(device)
|
||||
return self.shape_dec(slat, return_subs=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_tex_slat(self, slat, subs):
|
||||
if self.txt_dec is None:
|
||||
raise ValueError("Checkpoint doesn't include texture model")
|
||||
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
|
||||
|
||||
# shouldn't be called (placeholder)
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
@ -1546,17 +1441,4 @@ class Vae(nn.Module):
|
||||
):
|
||||
meshes, subs = self.decode_shape_slat(shape_slat, resolution)
|
||||
tex_voxels = self.decode_tex_slat(tex_slat, subs)
|
||||
out_mesh = []
|
||||
for m, v in zip(meshes, tex_voxels):
|
||||
out_mesh.append(
|
||||
MeshWithVoxel(
|
||||
m.vertices, m.faces,
|
||||
origin = [-0.5, -0.5, -0.5],
|
||||
voxel_size = 1 / resolution,
|
||||
coords = v.coords[:, 1:],
|
||||
attrs = v.feats,
|
||||
voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
|
||||
layout=self.pbr_attr_layout
|
||||
)
|
||||
)
|
||||
return out_mesh
|
||||
return tex_voxels
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy.utils import ProgressBar, lanczos
|
||||
import torch.nn.functional as TF
|
||||
import comfy.model_management
|
||||
from PIL import Image
|
||||
@ -102,9 +102,7 @@ def run_conditioning(model, image, mask, include_1024 = True, background_color =
|
||||
cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb)
|
||||
|
||||
def prepare_tensor(img, size):
|
||||
resized = torch.nn.functional.interpolate(
|
||||
img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False
|
||||
)
|
||||
resized = lanczos(img.unsqueeze(0), size, size)
|
||||
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||
|
||||
model_internal.image_size = 512
|
||||
@ -148,10 +146,16 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, structure_output, vae, resolution):
|
||||
|
||||
patcher = vae.patcher
|
||||
device = comfy.model_management.get_torch_device()
|
||||
comfy.model_management.load_model_gpu(patcher)
|
||||
|
||||
vae = vae.first_stage_model
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
samples = samples["samples"]
|
||||
samples = samples.squeeze(-1).transpose(1, 2).to(device)
|
||||
std = shape_slat_normalization["std"].to(samples)
|
||||
mean = shape_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords)
|
||||
@ -179,10 +183,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, structure_output, vae, shape_subs):
|
||||
|
||||
patcher = vae.patcher
|
||||
device = comfy.model_management.get_torch_device()
|
||||
comfy.model_management.load_model_gpu(patcher)
|
||||
|
||||
vae = vae.first_stage_model
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
samples = samples["samples"]
|
||||
samples = samples.squeeze(-1).transpose(1, 2).to(device)
|
||||
std = tex_slat_normalization["std"].to(samples)
|
||||
mean = tex_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user