mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
debugging
This commit is contained in:
parent
f4059c189e
commit
b7764479c2
@ -14,6 +14,7 @@ except:
|
||||
def scaled_dot_product_attention(*args, **kwargs):
|
||||
num_all_args = len(args) + len(kwargs)
|
||||
|
||||
q = None
|
||||
if num_all_args == 1:
|
||||
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
||||
|
||||
@ -26,8 +27,10 @@ def scaled_dot_product_attention(*args, **kwargs):
|
||||
k = args[1] if len(args) > 1 else kwargs['k']
|
||||
v = args[2] if len(args) > 2 else kwargs['v']
|
||||
|
||||
# TODO verify
|
||||
heads = q or qkv
|
||||
if q is not None:
|
||||
heads = q
|
||||
else:
|
||||
heads = qkv
|
||||
heads = heads.shape[2]
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
|
||||
@ -7,7 +7,6 @@ from comfy.ldm.trellis2.attention import (
|
||||
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
||||
)
|
||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||
from comfy.nested_tensor import NestedTensor
|
||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||
|
||||
class SparseGELU(nn.GELU):
|
||||
@ -586,6 +585,7 @@ class MultiHeadAttention(nn.Module):
|
||||
else:
|
||||
Lkv = context.shape[1]
|
||||
q = self.to_q(x)
|
||||
context = context.to(next(self.to_kv.parameters()).dtype)
|
||||
kv = self.to_kv(context)
|
||||
q = q.reshape(B, L, self.num_heads, -1)
|
||||
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
||||
@ -782,6 +782,7 @@ class SparseStructureFlowModel(nn.Module):
|
||||
h = block(h, t_emb, cond, self.rope_phases)
|
||||
h = manual_cast(h, x.dtype)
|
||||
h = F.layer_norm(h, h.shape[-1:])
|
||||
h = h.to(next(self.out_layer.parameters()).dtype)
|
||||
h = self.out_layer(h)
|
||||
|
||||
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
|
||||
@ -823,9 +824,7 @@ class Trellis2(nn.Module):
|
||||
args.pop("out_channels")
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||
|
||||
def forward(self, x: NestedTensor, timestep, context, **kwargs):
|
||||
if isinstance(x, NestedTensor):
|
||||
x = x.tensors[0]
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
embeds = kwargs.get("embeds")
|
||||
if not hasattr(x, "feats"):
|
||||
mode = "structure_generation"
|
||||
@ -843,6 +842,5 @@ class Trellis2(nn.Module):
|
||||
timestep = timestep_reshift(timestep)
|
||||
out = self.structure_model(x, timestep, context)
|
||||
|
||||
out = NestedTensor([out])
|
||||
out.generation_mode = mode
|
||||
return out
|
||||
|
||||
@ -10,9 +10,6 @@ from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_
|
||||
|
||||
|
||||
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
||||
"""
|
||||
3D pixel shuffle.
|
||||
"""
|
||||
B, C, H, W, D = x.shape
|
||||
C_ = C // scale_factor**3
|
||||
x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
|
||||
@ -967,6 +964,25 @@ class SparseLinear(nn.Linear):
|
||||
return input.replace(super().forward(input.feats))
|
||||
|
||||
|
||||
MIX_PRECISION_MODULES = (
|
||||
nn.Conv1d,
|
||||
nn.Conv2d,
|
||||
nn.Conv3d,
|
||||
nn.ConvTranspose1d,
|
||||
nn.ConvTranspose2d,
|
||||
nn.ConvTranspose3d,
|
||||
nn.Linear,
|
||||
SparseConv3d,
|
||||
SparseLinear,
|
||||
)
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
if isinstance(l, MIX_PRECISION_MODULES):
|
||||
for p in l.parameters():
|
||||
p.data = p.data.half()
|
||||
|
||||
|
||||
|
||||
class SparseUnetVaeEncoder(nn.Module):
|
||||
"""
|
||||
@ -1381,8 +1397,12 @@ class ResBlock3d(nn.Module):
|
||||
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
self.norm1 = self.norm1.to(torch.float32)
|
||||
self.norm2 = self.norm2.to(torch.float32)
|
||||
h = self.norm1(x)
|
||||
h = F.silu(h)
|
||||
dtype = next(self.conv1.parameters()).dtype
|
||||
h = h.to(dtype)
|
||||
h = self.conv1(h)
|
||||
h = self.norm2(h)
|
||||
h = F.silu(h)
|
||||
@ -1400,7 +1420,7 @@ class SparseStructureDecoder(nn.Module):
|
||||
channels: List[int],
|
||||
num_res_blocks_middle: int = 2,
|
||||
norm_type = "layer",
|
||||
use_fp16: bool = False,
|
||||
use_fp16: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
@ -1439,20 +1459,27 @@ class SparseStructureDecoder(nn.Module):
|
||||
if use_fp16:
|
||||
self.convert_to_fp16()
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def convert_to_fp16(self) -> None:
|
||||
self.use_fp16 = True
|
||||
self.dtype = torch.float16
|
||||
self.blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dtype = next(self.input_layer.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
h = self.input_layer(x)
|
||||
|
||||
h = h.type(self.dtype)
|
||||
|
||||
h = self.middle_block(h)
|
||||
for block in self.blocks:
|
||||
h = block(h)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = h.to(torch.float32)
|
||||
self.out_layer = self.out_layer.to(torch.float32)
|
||||
h = self.out_layer(h)
|
||||
return h
|
||||
|
||||
|
||||
@ -497,6 +497,10 @@ class VAE:
|
||||
init_txt_model = False
|
||||
if "txt_dec.blocks.1.16.norm1.weight" in sd:
|
||||
init_txt_model = True
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# TODO
|
||||
self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
|
||||
@ -3,7 +3,7 @@ from comfy_api.latest import ComfyExtension, IO
|
||||
import torch
|
||||
from comfy.ldm.trellis2.model import SparseTensor
|
||||
import comfy.model_management
|
||||
from comfy.nested_tensor import NestedTensor
|
||||
import comfy.model_patcher
|
||||
from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode
|
||||
|
||||
shape_slat_normalization = {
|
||||
@ -137,14 +137,15 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples: NestedTensor, vae, resolution):
|
||||
samples = samples.tensors[0]
|
||||
def execute(cls, samples, vae, resolution):
|
||||
vae = vae.first_stage_model
|
||||
samples = samples["samples"]
|
||||
std = shape_slat_normalization["std"]
|
||||
mean = shape_slat_normalization["mean"]
|
||||
samples = samples * std + mean
|
||||
|
||||
mesh, subs = vae.decode_shape_slat(resolution, samples)
|
||||
return mesh, subs
|
||||
return IO.NodeOutput(mesh, subs)
|
||||
|
||||
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -164,13 +165,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, shape_subs):
|
||||
samples = samples.tensors[0]
|
||||
vae = vae.first_stage_model
|
||||
samples = samples["samples"]
|
||||
std = tex_slat_normalization["std"]
|
||||
mean = tex_slat_normalization["mean"]
|
||||
samples = samples * std + mean
|
||||
|
||||
mesh = vae.decode_tex_slat(samples, shape_subs)
|
||||
return mesh
|
||||
return IO.NodeOutput(mesh)
|
||||
|
||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -189,10 +191,19 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae):
|
||||
vae = vae.first_stage_model
|
||||
decoder = vae.struct_dec
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
decoder = comfy.model_patcher.ModelPatcher(
|
||||
decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device()
|
||||
)
|
||||
comfy.model_management.load_model_gpu(decoder)
|
||||
decoder = decoder.model
|
||||
samples = samples["samples"]
|
||||
samples = samples.to(load_device)
|
||||
decoded = decoder(samples)>0
|
||||
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
|
||||
return coords
|
||||
return IO.NodeOutput(coords)
|
||||
|
||||
class Trellis2Conditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -240,7 +251,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
coords = structure_output # or structure_output.coords
|
||||
in_channels = 32
|
||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||
latent = NestedTensor([latent])
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
@ -262,7 +272,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
# TODO
|
||||
in_channels = 32
|
||||
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||
latent = NestedTensor([latent])
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
@ -283,7 +292,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
in_channels = 8
|
||||
resolution = 16
|
||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||
latent = NestedTensor([latent])
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
def simplify_fn(vertices, faces, target=100000):
|
||||
@ -469,7 +477,7 @@ class PostProcessMesh(IO.ComfyNode):
|
||||
mesh.vertices = verts
|
||||
mesh.faces = faces
|
||||
|
||||
return mesh
|
||||
return IO.NodeOutput(mesh)
|
||||
|
||||
class Trellis2Extension(ComfyExtension):
|
||||
@override
|
||||
|
||||
Loading…
Reference in New Issue
Block a user