ComfyUI (https://github.com/comfyanonymous/ComfyUI.git)
commit 253ee4c02c ("fixes")
parent 1fde60b2bc
@@ -44,9 +44,6 @@ class DINOv3ViTAttention(nn.Module):
         self.num_heads = num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads

-        self.scaling = self.head_dim**-0.5
-        self.is_causal = False
-
         self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
         self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
@@ -251,7 +248,7 @@ class DINOv3ViTModel(nn.Module):
             intermediate_size=intermediate_size,num_attention_heads = num_attention_heads,
             dtype=dtype, device=device, operations=operations)
             for _ in range(num_hidden_layers)])
-        self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
+        self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)

     def get_input_embeddings(self):
         return self.embeddings.patch_embeddings
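Context for the operations.LayerNorm change: the rest of the model already builds its layers through the injected operations namespace so that ComfyUI can manage the weights (skip default init, cast to the activation dtype/device at run time); the plain nn.LayerNorm was the one holdout. A minimal sketch of the idea, using illustrative class names rather than the actual comfy.ops implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CastingLayerNorm(nn.LayerNorm):
    # illustrative stand-in for an operations.LayerNorm: casts its parameters to the
    # input's dtype/device at forward time instead of assuming they already match
    def forward(self, x):
        w = self.weight.to(dtype=x.dtype, device=x.device)
        b = self.bias.to(dtype=x.dtype, device=x.device)
        return F.layer_norm(x, self.normalized_shape, w, b, self.eps)

class operations:  # stand-in for the namespace that gets passed into the model
    LayerNorm = CastingLayerNorm

norm = operations.LayerNorm(768, eps=1e-5)         # same call signature as nn.LayerNorm
x = torch.randn(2, 16, 768, dtype=torch.bfloat16)
print(norm(x).dtype)                               # torch.bfloat16, no manual casting needed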
@@ -4,7 +4,7 @@ import math
 import torch
 from typing import Dict, Callable

-NO_TRITION = False
+NO_TRITON = False
 try:
     allow_tf32 = torch.cuda.is_tf32_supported()
 except Exception:
@@ -115,8 +115,8 @@ try:
             allow_tf32=allow_tf32,
         )
         return output
-except:
-    NO_TRITION = True
+except Exception:
+    NO_TRITON = True

 def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
     # offsets in same order as CUDA kernel
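Two fixes land in these hunks: the availability flag is now consistently spelled NO_TRITON, and the bare except: is narrowed to except Exception:, so a KeyboardInterrupt or SystemExit raised while the Triton path is being set up is no longer swallowed and misreported as "Triton unavailable". A standalone sketch of the probe pattern (assuming only that triton is an optional import):

NO_TRITON = False
try:
    import triton                # optional dependency; may simply not be installed
    import triton.language as tl
except Exception:                # ImportError, broken install, etc., but not Ctrl-C
    NO_TRITON = True

print("triton available:", not NO_TRITON)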
@@ -364,6 +364,8 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(


 def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
+    if NO_TRITON: # TODO
+        raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
     if len(shape) == 5:
         N, C, W, H, D = shape
     else:
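The added guard fails fast at the public entry point with an actionable message instead of letting the missing dependency surface later as an opaque error inside the kernel path. A standalone sketch of that pattern (stub function, not the real signature):

NO_TRITON = True    # pretend the Triton import probe failed

def sparse_conv_stub(feats):
    if NO_TRITON:
        raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
    ...  # kernel launch would go here

try:
    sparse_conv_stub(None)
except RuntimeError as err:
    print(err)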
@@ -697,8 +697,6 @@ class SparseStructureFlowModel(nn.Module):

     def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
         x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
-        assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
-            f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"

         h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
@@ -746,7 +744,8 @@ class Trellis2(nn.Module):
         super().__init__()
         self.dtype = dtype
         # for some reason it passes num_heads = -1
-        num_heads = 12
+        if num_heads == -1:
+            num_heads = 12
         args = {
             "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
             "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
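With the conditional, a placeholder value of -1 still falls back to 12 heads, but a real value coming from the config is no longer overwritten. A small sketch of the difference (the numbers here are made up, not from an actual checkpoint):

def resolve_num_heads(num_heads, default=12):
    # only substitute the default when the placeholder -1 is passed through
    return default if num_heads == -1 else num_heads

print(resolve_num_heads(-1))   # 12: placeholder falls back to the default
print(resolve_num_heads(16))   # 16: an explicit value is now preserved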
@@ -763,8 +762,10 @@ class Trellis2(nn.Module):
     def forward(self, x, timestep, context, **kwargs):
         # FIXME: should find a way to distinguish between 512/1024 models
         # currently assumes 1024
-        transformer_options = kwargs.get("transformer_options")
+        transformer_options = kwargs.get("transformer_options", {})
         embeds = kwargs.get("embeds")
+        if embeds is None:
+            raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
         #_, cond = context.chunk(2) # TODO
         cond = embeds.chunk(2)[0]
         context = torch.cat([torch.zeros_like(cond), cond])
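kwargs.get("transformer_options") returns None when the key is absent, so any later transformer_options.get(...) lookup raises AttributeError; defaulting to {} keeps those lookups safe, and the explicit embeds check turns a later crash on None into a clear ValueError. A standalone sketch of the .get default:

kwargs = {}  # caller did not pass transformer_options

opts_old = kwargs.get("transformer_options")        # None
opts_new = kwargs.get("transformer_options", {})    # {}

print(opts_new.get("patches", []))                  # [] -- safe lookup on the empty dict
try:
    opts_old.get("patches", [])
except AttributeError as err:
    print("old behaviour:", err)                    # 'NoneType' object has no attribute 'get'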
@@ -807,6 +808,8 @@ class Trellis2(nn.Module):
             # TODO
             out = self.img2shape(x, timestep, context)
         elif mode == "texture_generation":
+            if self.shape2txt is None:
+                raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
             out = self.shape2txt(x, timestep, context if not txt_rule else cond)
         else: # structure
             timestep = timestep_reshift(timestep)
@@ -1522,6 +1522,8 @@ class Vae(nn.Module):
         return self.shape_dec(slat, return_subs=True)

     def decode_tex_slat(self, slat, subs):
+        if self.txt_dec is None:
+            raise ValueError("Checkpoint doesn't include texture model")
         return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5

     @torch.no_grad()
@@ -1,9 +1,11 @@
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, IO, Types
+import torch
+import torch.nn.functional as TF
 import comfy.model_management
 from PIL import Image
 import numpy as np
 import torch
 import copy

 shape_slat_normalization = {
     "mean": torch.tensor([
@@ -145,11 +147,11 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
     def execute(cls, samples, vae, resolution):
         vae = vae.first_stage_model
         samples = samples["samples"]
-        std = shape_slat_normalization["std"]
-        mean = shape_slat_normalization["mean"]
+        std = shape_slat_normalization["std"].to(samples)
+        mean = shape_slat_normalization["mean"].to(samples)
         samples = samples * std + mean

-        mesh, subs = vae.decode_shape_slat(resolution, samples)
+        mesh, subs = vae.decode_shape_slat(samples, resolution)
         return IO.NodeOutput(mesh, subs)

 class VaeDecodeTextureTrellis(IO.ComfyNode):
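The normalization tensors are created at import time on the CPU in float32, while the latents can be fp16/bf16 and on the GPU; Tensor.to(other) matches both the device and the dtype of other, which is what the .to(samples) calls rely on (the hunk also corrects the argument order of decode_shape_slat to samples first). A sketch with placeholder statistics, not the real shape_slat_normalization values:

import torch

stats = {"mean": torch.zeros(8), "std": torch.ones(8)}   # built on CPU in float32

samples = torch.randn(2, 8, dtype=torch.bfloat16)        # stand-in for the latents
# samples = samples.cuda()                               # on a GPU box they would also live on cuda

std = stats["std"].to(samples)     # .to(tensor) adopts BOTH the device and dtype of `samples`
mean = stats["mean"].to(samples)
out = samples * std + mean         # no device/dtype mismatch error
print(out.dtype, out.device)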
@@ -172,8 +174,8 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
     def execute(cls, samples, vae, shape_subs):
         vae = vae.first_stage_model
         samples = samples["samples"]
-        std = tex_slat_normalization["std"]
-        mean = tex_slat_normalization["mean"]
+        std = tex_slat_normalization["std"].to(samples)
+        mean = tex_slat_normalization["mean"].to(samples)
         samples = samples * std + mean

         mesh = vae.decode_tex_slat(samples, shape_subs)
@@ -239,6 +241,8 @@ class Trellis2Conditioning(IO.ComfyNode):
         scale = min(1, 1024 / max_size)
         if scale < 1:
             image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
+            new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale)
+            mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0)

         image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255
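When the reference image is downscaled to fit a 1024-pixel longest side, the mask is now resized by the same factor so it stays aligned with the image; nearest interpolation keeps it binary. A standalone sketch with made-up sizes:

import torch
import torch.nn.functional as TF

mask = (torch.rand(1, 900, 1400) > 0.5).float()   # (C, H, W) mask for a 1400-px-wide image
scale = min(1, 1024 / 1400)                        # same rule as the node: cap the longest side at 1024

new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale)
mask_small = TF.interpolate(mask.unsqueeze(0), size=(new_h, new_w), mode='nearest').squeeze(0)
print(mask_small.shape)     # torch.Size([1, 658, 1024]), matching the resized image
print(mask_small.unique())  # still only 0. and 1. thanks to nearest-neighbour resampling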
@@ -510,6 +514,7 @@ class PostProcessMesh(IO.ComfyNode):
     )
     @classmethod
     def execute(cls, mesh, simplify, fill_holes_perimeter):
+        mesh = copy.deepcopy(mesh)
         verts, faces = mesh.vertices, mesh.faces

         if fill_holes_perimeter != 0.0:
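Deep-copying first matters because ComfyUI caches node outputs and can hand the same object to several downstream nodes, so editing the mesh in place would leak the post-processing into other branches and into re-runs. A generic sketch (the Mesh class below is a stand-in, not the actual mesh type):

import copy

class Mesh:
    def __init__(self, vertices):
        self.vertices = vertices

cached_output = Mesh(vertices=[(0, 0, 0), (1, 0, 0), (0, 1, 0)])

def postprocess(mesh):
    mesh = copy.deepcopy(mesh)         # work on a private copy, as the fix does
    mesh.vertices = mesh.vertices[:2]  # simulate simplification
    return mesh

simplified = postprocess(cached_output)
print(len(cached_output.vertices), len(simplified.vertices))   # 3 2: the cached mesh is untouched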