mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
debugging
This commit is contained in:
parent
f4059c189e
commit
b7764479c2
@ -14,6 +14,7 @@ except:
|
|||||||
def scaled_dot_product_attention(*args, **kwargs):
|
def scaled_dot_product_attention(*args, **kwargs):
|
||||||
num_all_args = len(args) + len(kwargs)
|
num_all_args = len(args) + len(kwargs)
|
||||||
|
|
||||||
|
q = None
|
||||||
if num_all_args == 1:
|
if num_all_args == 1:
|
||||||
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
||||||
|
|
||||||
@ -26,8 +27,10 @@ def scaled_dot_product_attention(*args, **kwargs):
|
|||||||
k = args[1] if len(args) > 1 else kwargs['k']
|
k = args[1] if len(args) > 1 else kwargs['k']
|
||||||
v = args[2] if len(args) > 2 else kwargs['v']
|
v = args[2] if len(args) > 2 else kwargs['v']
|
||||||
|
|
||||||
# TODO verify
|
if q is not None:
|
||||||
heads = q or qkv
|
heads = q
|
||||||
|
else:
|
||||||
|
heads = qkv
|
||||||
heads = heads.shape[2]
|
heads = heads.shape[2]
|
||||||
|
|
||||||
if optimized_attention.__name__ == 'attention_xformers':
|
if optimized_attention.__name__ == 'attention_xformers':
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from comfy.ldm.trellis2.attention import (
|
|||||||
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
||||||
)
|
)
|
||||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||||
from comfy.nested_tensor import NestedTensor
|
|
||||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||||
|
|
||||||
class SparseGELU(nn.GELU):
|
class SparseGELU(nn.GELU):
|
||||||
@ -586,6 +585,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
Lkv = context.shape[1]
|
Lkv = context.shape[1]
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
|
context = context.to(next(self.to_kv.parameters()).dtype)
|
||||||
kv = self.to_kv(context)
|
kv = self.to_kv(context)
|
||||||
q = q.reshape(B, L, self.num_heads, -1)
|
q = q.reshape(B, L, self.num_heads, -1)
|
||||||
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
||||||
@ -782,6 +782,7 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
h = block(h, t_emb, cond, self.rope_phases)
|
h = block(h, t_emb, cond, self.rope_phases)
|
||||||
h = manual_cast(h, x.dtype)
|
h = manual_cast(h, x.dtype)
|
||||||
h = F.layer_norm(h, h.shape[-1:])
|
h = F.layer_norm(h, h.shape[-1:])
|
||||||
|
h = h.to(next(self.out_layer.parameters()).dtype)
|
||||||
h = self.out_layer(h)
|
h = self.out_layer(h)
|
||||||
|
|
||||||
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
|
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
|
||||||
@ -823,9 +824,7 @@ class Trellis2(nn.Module):
|
|||||||
args.pop("out_channels")
|
args.pop("out_channels")
|
||||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||||
|
|
||||||
def forward(self, x: NestedTensor, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
if isinstance(x, NestedTensor):
|
|
||||||
x = x.tensors[0]
|
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
if not hasattr(x, "feats"):
|
if not hasattr(x, "feats"):
|
||||||
mode = "structure_generation"
|
mode = "structure_generation"
|
||||||
@ -843,6 +842,5 @@ class Trellis2(nn.Module):
|
|||||||
timestep = timestep_reshift(timestep)
|
timestep = timestep_reshift(timestep)
|
||||||
out = self.structure_model(x, timestep, context)
|
out = self.structure_model(x, timestep, context)
|
||||||
|
|
||||||
out = NestedTensor([out])
|
|
||||||
out.generation_mode = mode
|
out.generation_mode = mode
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -10,9 +10,6 @@ from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_
|
|||||||
|
|
||||||
|
|
||||||
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
||||||
"""
|
|
||||||
3D pixel shuffle.
|
|
||||||
"""
|
|
||||||
B, C, H, W, D = x.shape
|
B, C, H, W, D = x.shape
|
||||||
C_ = C // scale_factor**3
|
C_ = C // scale_factor**3
|
||||||
x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
|
x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
|
||||||
@ -967,6 +964,25 @@ class SparseLinear(nn.Linear):
|
|||||||
return input.replace(super().forward(input.feats))
|
return input.replace(super().forward(input.feats))
|
||||||
|
|
||||||
|
|
||||||
|
MIX_PRECISION_MODULES = (
|
||||||
|
nn.Conv1d,
|
||||||
|
nn.Conv2d,
|
||||||
|
nn.Conv3d,
|
||||||
|
nn.ConvTranspose1d,
|
||||||
|
nn.ConvTranspose2d,
|
||||||
|
nn.ConvTranspose3d,
|
||||||
|
nn.Linear,
|
||||||
|
SparseConv3d,
|
||||||
|
SparseLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_module_to_f16(l):
|
||||||
|
if isinstance(l, MIX_PRECISION_MODULES):
|
||||||
|
for p in l.parameters():
|
||||||
|
p.data = p.data.half()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SparseUnetVaeEncoder(nn.Module):
|
class SparseUnetVaeEncoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
@ -1381,8 +1397,12 @@ class ResBlock3d(nn.Module):
|
|||||||
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
self.norm1 = self.norm1.to(torch.float32)
|
||||||
|
self.norm2 = self.norm2.to(torch.float32)
|
||||||
h = self.norm1(x)
|
h = self.norm1(x)
|
||||||
h = F.silu(h)
|
h = F.silu(h)
|
||||||
|
dtype = next(self.conv1.parameters()).dtype
|
||||||
|
h = h.to(dtype)
|
||||||
h = self.conv1(h)
|
h = self.conv1(h)
|
||||||
h = self.norm2(h)
|
h = self.norm2(h)
|
||||||
h = F.silu(h)
|
h = F.silu(h)
|
||||||
@ -1400,7 +1420,7 @@ class SparseStructureDecoder(nn.Module):
|
|||||||
channels: List[int],
|
channels: List[int],
|
||||||
num_res_blocks_middle: int = 2,
|
num_res_blocks_middle: int = 2,
|
||||||
norm_type = "layer",
|
norm_type = "layer",
|
||||||
use_fp16: bool = False,
|
use_fp16: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
@ -1439,20 +1459,27 @@ class SparseStructureDecoder(nn.Module):
|
|||||||
if use_fp16:
|
if use_fp16:
|
||||||
self.convert_to_fp16()
|
self.convert_to_fp16()
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
return next(self.parameters()).device
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def convert_to_fp16(self) -> None:
|
||||||
|
self.use_fp16 = True
|
||||||
|
self.dtype = torch.float16
|
||||||
|
self.blocks.apply(convert_module_to_f16)
|
||||||
|
self.middle_block.apply(convert_module_to_f16)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
dtype = next(self.input_layer.parameters()).dtype
|
||||||
|
x = x.to(dtype)
|
||||||
h = self.input_layer(x)
|
h = self.input_layer(x)
|
||||||
|
|
||||||
h = h.type(self.dtype)
|
h = h.type(self.dtype)
|
||||||
|
|
||||||
h = self.middle_block(h)
|
h = self.middle_block(h)
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
h = block(h)
|
h = block(h)
|
||||||
|
|
||||||
h = h.type(x.dtype)
|
h = h.to(torch.float32)
|
||||||
|
self.out_layer = self.out_layer.to(torch.float32)
|
||||||
h = self.out_layer(h)
|
h = self.out_layer(h)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|||||||
@ -497,6 +497,10 @@ class VAE:
|
|||||||
init_txt_model = False
|
init_txt_model = False
|
||||||
if "txt_dec.blocks.1.16.norm1.weight" in sd:
|
if "txt_dec.blocks.1.16.norm1.weight" in sd:
|
||||||
init_txt_model = True
|
init_txt_model = True
|
||||||
|
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
|
# TODO
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
|
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
|
||||||
elif "decoder.conv_in.weight" in sd:
|
elif "decoder.conv_in.weight" in sd:
|
||||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from comfy_api.latest import ComfyExtension, IO
|
|||||||
import torch
|
import torch
|
||||||
from comfy.ldm.trellis2.model import SparseTensor
|
from comfy.ldm.trellis2.model import SparseTensor
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.nested_tensor import NestedTensor
|
import comfy.model_patcher
|
||||||
from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode
|
from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode
|
||||||
|
|
||||||
shape_slat_normalization = {
|
shape_slat_normalization = {
|
||||||
@ -137,14 +137,15 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples: NestedTensor, vae, resolution):
|
def execute(cls, samples, vae, resolution):
|
||||||
samples = samples.tensors[0]
|
vae = vae.first_stage_model
|
||||||
|
samples = samples["samples"]
|
||||||
std = shape_slat_normalization["std"]
|
std = shape_slat_normalization["std"]
|
||||||
mean = shape_slat_normalization["mean"]
|
mean = shape_slat_normalization["mean"]
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
mesh, subs = vae.decode_shape_slat(resolution, samples)
|
mesh, subs = vae.decode_shape_slat(resolution, samples)
|
||||||
return mesh, subs
|
return IO.NodeOutput(mesh, subs)
|
||||||
|
|
||||||
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -164,13 +165,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae, shape_subs):
|
def execute(cls, samples, vae, shape_subs):
|
||||||
samples = samples.tensors[0]
|
vae = vae.first_stage_model
|
||||||
|
samples = samples["samples"]
|
||||||
std = tex_slat_normalization["std"]
|
std = tex_slat_normalization["std"]
|
||||||
mean = tex_slat_normalization["mean"]
|
mean = tex_slat_normalization["mean"]
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
mesh = vae.decode_tex_slat(samples, shape_subs)
|
mesh = vae.decode_tex_slat(samples, shape_subs)
|
||||||
return mesh
|
return IO.NodeOutput(mesh)
|
||||||
|
|
||||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -189,10 +191,19 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae):
|
def execute(cls, samples, vae):
|
||||||
|
vae = vae.first_stage_model
|
||||||
decoder = vae.struct_dec
|
decoder = vae.struct_dec
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
decoder = comfy.model_patcher.ModelPatcher(
|
||||||
|
decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device()
|
||||||
|
)
|
||||||
|
comfy.model_management.load_model_gpu(decoder)
|
||||||
|
decoder = decoder.model
|
||||||
|
samples = samples["samples"]
|
||||||
|
samples = samples.to(load_device)
|
||||||
decoded = decoder(samples)>0
|
decoded = decoder(samples)>0
|
||||||
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
|
||||||
return coords
|
return IO.NodeOutput(coords)
|
||||||
|
|
||||||
class Trellis2Conditioning(IO.ComfyNode):
|
class Trellis2Conditioning(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -240,7 +251,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
coords = structure_output # or structure_output.coords
|
coords = structure_output # or structure_output.coords
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||||
latent = NestedTensor([latent])
|
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
@ -262,7 +272,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
# TODO
|
# TODO
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
|
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||||
latent = NestedTensor([latent])
|
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
@ -283,7 +292,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
in_channels = 8
|
in_channels = 8
|
||||||
resolution = 16
|
resolution = 16
|
||||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||||
latent = NestedTensor([latent])
|
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
def simplify_fn(vertices, faces, target=100000):
|
def simplify_fn(vertices, faces, target=100000):
|
||||||
@ -469,7 +477,7 @@ class PostProcessMesh(IO.ComfyNode):
|
|||||||
mesh.vertices = verts
|
mesh.vertices = verts
|
||||||
mesh.faces = faces
|
mesh.faces = faces
|
||||||
|
|
||||||
return mesh
|
return IO.NodeOutput(mesh)
|
||||||
|
|
||||||
class Trellis2Extension(ComfyExtension):
|
class Trellis2Extension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user