mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-17 05:52:31 +08:00
fixes
This commit is contained in:
parent
1fde60b2bc
commit
253ee4c02c
@ -44,9 +44,6 @@ class DINOv3ViTAttention(nn.Module):
|
|||||||
self.num_heads = num_attention_heads
|
self.num_heads = num_attention_heads
|
||||||
self.head_dim = self.embed_dim // self.num_heads
|
self.head_dim = self.embed_dim // self.num_heads
|
||||||
|
|
||||||
self.scaling = self.head_dim**-0.5
|
|
||||||
self.is_causal = False
|
|
||||||
|
|
||||||
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
|
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
|
||||||
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
@ -251,7 +248,7 @@ class DINOv3ViTModel(nn.Module):
|
|||||||
intermediate_size=intermediate_size,num_attention_heads = num_attention_heads,
|
intermediate_size=intermediate_size,num_attention_heads = num_attention_heads,
|
||||||
dtype=dtype, device=device, operations=operations)
|
dtype=dtype, device=device, operations=operations)
|
||||||
for _ in range(num_hidden_layers)])
|
for _ in range(num_hidden_layers)])
|
||||||
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
|
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.embeddings.patch_embeddings
|
return self.embeddings.patch_embeddings
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
from typing import Dict, Callable
|
from typing import Dict, Callable
|
||||||
|
|
||||||
NO_TRITION = False
|
NO_TRITON = False
|
||||||
try:
|
try:
|
||||||
allow_tf32 = torch.cuda.is_tf32_supported()
|
allow_tf32 = torch.cuda.is_tf32_supported()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -115,8 +115,8 @@ try:
|
|||||||
allow_tf32=allow_tf32,
|
allow_tf32=allow_tf32,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
except:
|
except Exception:
|
||||||
NO_TRITION = True
|
NO_TRITON = True
|
||||||
|
|
||||||
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
|
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
|
||||||
# offsets in same order as CUDA kernel
|
# offsets in same order as CUDA kernel
|
||||||
@ -364,6 +364,8 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
|||||||
|
|
||||||
|
|
||||||
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
|
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
|
||||||
|
if NO_TRITON: # TODO
|
||||||
|
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
|
||||||
if len(shape) == 5:
|
if len(shape) == 5:
|
||||||
N, C, W, H, D = shape
|
N, C, W, H, D = shape
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -697,8 +697,6 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
||||||
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
|
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
|
||||||
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
|
|
||||||
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
|
|
||||||
|
|
||||||
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
|
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
@ -746,7 +744,8 @@ class Trellis2(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
# for some reason it passes num_heads = -1
|
# for some reason it passes num_heads = -1
|
||||||
num_heads = 12
|
if num_heads == -1:
|
||||||
|
num_heads = 12
|
||||||
args = {
|
args = {
|
||||||
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
|
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
|
||||||
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
||||||
@ -763,8 +762,10 @@ class Trellis2(nn.Module):
|
|||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
# FIXME: should find a way to distinguish between 512/1024 models
|
# FIXME: should find a way to distinguish between 512/1024 models
|
||||||
# currently assumes 1024
|
# currently assumes 1024
|
||||||
transformer_options = kwargs.get("transformer_options")
|
transformer_options = kwargs.get("transformer_options", {})
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
|
if embeds is None:
|
||||||
|
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||||
#_, cond = context.chunk(2) # TODO
|
#_, cond = context.chunk(2) # TODO
|
||||||
cond = embeds.chunk(2)[0]
|
cond = embeds.chunk(2)[0]
|
||||||
context = torch.cat([torch.zeros_like(cond), cond])
|
context = torch.cat([torch.zeros_like(cond), cond])
|
||||||
@ -807,6 +808,8 @@ class Trellis2(nn.Module):
|
|||||||
# TODO
|
# TODO
|
||||||
out = self.img2shape(x, timestep, context)
|
out = self.img2shape(x, timestep, context)
|
||||||
elif mode == "texture_generation":
|
elif mode == "texture_generation":
|
||||||
|
if self.shape2txt is None:
|
||||||
|
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||||
else: # structure
|
else: # structure
|
||||||
timestep = timestep_reshift(timestep)
|
timestep = timestep_reshift(timestep)
|
||||||
|
|||||||
@ -1522,6 +1522,8 @@ class Vae(nn.Module):
|
|||||||
return self.shape_dec(slat, return_subs=True)
|
return self.shape_dec(slat, return_subs=True)
|
||||||
|
|
||||||
def decode_tex_slat(self, slat, subs):
|
def decode_tex_slat(self, slat, subs):
|
||||||
|
if self.txt_dec is None:
|
||||||
|
raise ValueError("Checkpoint doesn't include texture model")
|
||||||
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
|
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
import torch
|
import torch.nn.functional as TF
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import copy
|
||||||
|
|
||||||
shape_slat_normalization = {
|
shape_slat_normalization = {
|
||||||
"mean": torch.tensor([
|
"mean": torch.tensor([
|
||||||
@ -145,11 +147,11 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
def execute(cls, samples, vae, resolution):
|
def execute(cls, samples, vae, resolution):
|
||||||
vae = vae.first_stage_model
|
vae = vae.first_stage_model
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
std = shape_slat_normalization["std"]
|
std = shape_slat_normalization["std"].to(samples)
|
||||||
mean = shape_slat_normalization["mean"]
|
mean = shape_slat_normalization["mean"].to(samples)
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
mesh, subs = vae.decode_shape_slat(resolution, samples)
|
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
||||||
return IO.NodeOutput(mesh, subs)
|
return IO.NodeOutput(mesh, subs)
|
||||||
|
|
||||||
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||||
@ -172,8 +174,8 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
def execute(cls, samples, vae, shape_subs):
|
def execute(cls, samples, vae, shape_subs):
|
||||||
vae = vae.first_stage_model
|
vae = vae.first_stage_model
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
std = tex_slat_normalization["std"]
|
std = tex_slat_normalization["std"].to(samples)
|
||||||
mean = tex_slat_normalization["mean"]
|
mean = tex_slat_normalization["mean"].to(samples)
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
mesh = vae.decode_tex_slat(samples, shape_subs)
|
mesh = vae.decode_tex_slat(samples, shape_subs)
|
||||||
@ -239,6 +241,8 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
scale = min(1, 1024 / max_size)
|
scale = min(1, 1024 / max_size)
|
||||||
if scale < 1:
|
if scale < 1:
|
||||||
image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
|
image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
|
||||||
|
new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale)
|
||||||
|
mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0)
|
||||||
|
|
||||||
image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255
|
image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255
|
||||||
|
|
||||||
@ -510,6 +514,7 @@ class PostProcessMesh(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, mesh, simplify, fill_holes_perimeter):
|
def execute(cls, mesh, simplify, fill_holes_perimeter):
|
||||||
|
mesh = copy.deepcopy(mesh)
|
||||||
verts, faces = mesh.vertices, mesh.faces
|
verts, faces = mesh.vertices, mesh.faces
|
||||||
|
|
||||||
if fill_holes_perimeter != 0.0:
|
if fill_holes_perimeter != 0.0:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user