mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 21:51:00 +08:00
bunch of fixes
This commit is contained in:
parent
49d1eab2a5
commit
3fcdac8c7b
@ -2,6 +2,7 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||||
|
|
||||||
@ -225,7 +226,7 @@ class DINOv3ViTLayer(nn.Module):
|
|||||||
class DINOv3ViTModel(nn.Module):
|
class DINOv3ViTModel(nn.Module):
|
||||||
def __init__(self, config, dtype, device, operations):
|
def __init__(self, config, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if dtype == torch.float16:
|
if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False):
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
num_hidden_layers = config["num_hidden_layers"]
|
num_hidden_layers = config["num_hidden_layers"]
|
||||||
hidden_size = config["hidden_size"]
|
hidden_size = config["hidden_size"]
|
||||||
|
|||||||
@ -3,12 +3,56 @@ import math
|
|||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
from typing import Tuple, Union, List
|
from typing import Tuple, Union, List
|
||||||
from comfy.ldm.trellis2.vae import VarLenTensor
|
from comfy.ldm.trellis2.vae import VarLenTensor
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
|
||||||
|
# replica of the seedvr2 code
|
||||||
|
def var_attn_arg(kwargs):
|
||||||
|
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||||
|
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||||
|
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) or kwargs.get("cu_seqlens_kv", cu_seqlens_q)
|
||||||
|
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) or kwargs.get("max_kv_seqlen", max_seqlen_q)
|
||||||
|
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||||
|
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||||
|
|
||||||
|
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
var_length = True
|
||||||
|
if var_length:
|
||||||
|
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
|
||||||
|
if not skip_reshape:
|
||||||
|
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||||
|
total_tokens, embed_dim = q.shape
|
||||||
|
head_dim = embed_dim // heads
|
||||||
|
q = q.view(total_tokens, heads, head_dim)
|
||||||
|
k = k.view(k.shape[0], heads, head_dim)
|
||||||
|
v = v.view(v.shape[0], heads, head_dim)
|
||||||
|
|
||||||
|
b = q.size(0)
|
||||||
|
dim_head = q.shape[-1]
|
||||||
|
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||||
|
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||||
|
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if mask.ndim == 2:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
|
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
|
if var_length:
|
||||||
|
return out.contiguous().transpose(1, 2).values()
|
||||||
|
if not skip_output_reshape:
|
||||||
|
out = (
|
||||||
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
FLASH_ATTN_3_AVA = True
|
|
||||||
try:
|
|
||||||
import flash_attn_interface as flash_attn_3 # noqa: F401
|
|
||||||
except:
|
|
||||||
FLASH_ATTN_3_AVA = False
|
|
||||||
|
|
||||||
# TODO replace with optimized attention
|
# TODO replace with optimized attention
|
||||||
def scaled_dot_product_attention(*args, **kwargs):
|
def scaled_dot_product_attention(*args, **kwargs):
|
||||||
@ -40,18 +84,10 @@ def scaled_dot_product_attention(*args, **kwargs):
|
|||||||
k, v = kv.unbind(dim=2)
|
k, v = kv.unbind(dim=2)
|
||||||
#out = xops.memory_efficient_attention(q, k, v)
|
#out = xops.memory_efficient_attention(q, k, v)
|
||||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||||
elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA:
|
elif optimized_attention.__name__ == 'attention_flash':
|
||||||
if num_all_args == 2:
|
if num_all_args == 2:
|
||||||
k, v = kv.unbind(dim=2)
|
k, v = kv.unbind(dim=2)
|
||||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||||
elif optimized_attention.__name__ == 'attention_flash': # TODO
|
|
||||||
if 'flash_attn_3' not in globals():
|
|
||||||
import flash_attn_interface as flash_attn_3
|
|
||||||
if num_all_args == 2:
|
|
||||||
k, v = kv.unbind(dim=2)
|
|
||||||
out = flash_attn_3.flash_attn_func(q, k, v)
|
|
||||||
elif num_all_args == 3:
|
|
||||||
out = flash_attn_3.flash_attn_func(q, k, v)
|
|
||||||
elif optimized_attention.__name__ == 'attention_pytorch':
|
elif optimized_attention.__name__ == 'attention_pytorch':
|
||||||
if num_all_args == 1:
|
if num_all_args == 1:
|
||||||
q, k, v = qkv.unbind(dim=2)
|
q, k, v = qkv.unbind(dim=2)
|
||||||
@ -238,24 +274,16 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
|
|||||||
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||||
elif num_all_args == 3:
|
elif num_all_args == 3:
|
||||||
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||||
elif optimized_attention.__name__ == 'flash_attn_3': # TODO
|
|
||||||
if 'flash_attn_3' not in globals():
|
elif optimized_attention.__name__ == "attention_pytorch":
|
||||||
import flash_attn_interface as flash_attn_3
|
|
||||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||||
|
if num_all_args in [2, 3]:
|
||||||
|
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||||
if num_all_args == 1:
|
if num_all_args == 1:
|
||||||
q, k, v = qkv.unbind(dim=1)
|
q, k, v = qkv.unbind(dim=1)
|
||||||
cu_seqlens_kv = cu_seqlens_q.clone()
|
|
||||||
max_q_seqlen = max_kv_seqlen = max(q_seqlen)
|
|
||||||
elif num_all_args == 2:
|
elif num_all_args == 2:
|
||||||
k, v = kv.unbind(dim=1)
|
k, v = kv.unbind(dim=1)
|
||||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||||
max_q_seqlen = max(q_seqlen)
|
|
||||||
max_kv_seqlen = max(kv_seqlen)
|
|
||||||
elif num_all_args == 3:
|
|
||||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
|
||||||
max_q_seqlen = max(q_seqlen)
|
|
||||||
max_kv_seqlen = max(kv_seqlen)
|
|
||||||
out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)
|
|
||||||
|
|
||||||
if s is not None:
|
if s is not None:
|
||||||
return s.replace(out)
|
return s.replace(out)
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict, Callable
|
from typing import Dict, Callable
|
||||||
|
import logging
|
||||||
|
|
||||||
NO_TRITON = False
|
NO_TRITON = False
|
||||||
try:
|
try:
|
||||||
@ -366,6 +367,10 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
|||||||
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
|
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
|
||||||
if NO_TRITON: # TODO
|
if NO_TRITON: # TODO
|
||||||
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
|
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
|
||||||
|
if feats.shape[0] == 0:
|
||||||
|
logging.warning("Found feats to be empty!")
|
||||||
|
Co = weight.shape[0]
|
||||||
|
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
|
||||||
if len(shape) == 5:
|
if len(shape) == 5:
|
||||||
N, C, W, H, D = shape
|
N, C, W, H, D = shape
|
||||||
else:
|
else:
|
||||||
@ -427,9 +432,11 @@ class Voxel:
|
|||||||
voxel_size: float,
|
voxel_size: float,
|
||||||
coords: torch.Tensor = None,
|
coords: torch.Tensor = None,
|
||||||
attrs: torch.Tensor = None,
|
attrs: torch.Tensor = None,
|
||||||
layout: Dict = {},
|
layout = None,
|
||||||
device: torch.device = 'cuda'
|
device: torch.device = None
|
||||||
):
|
):
|
||||||
|
if layout is None:
|
||||||
|
layout = {}
|
||||||
self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
|
self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
|
||||||
self.voxel_size = voxel_size
|
self.voxel_size = voxel_size
|
||||||
self.coords = coords
|
self.coords = coords
|
||||||
|
|||||||
@ -630,7 +630,6 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
mlp_ratio: float = 4,
|
mlp_ratio: float = 4,
|
||||||
pe_mode: Literal["ape", "rope"] = "rope",
|
pe_mode: Literal["ape", "rope"] = "rope",
|
||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||||
dtype: str = 'float32',
|
|
||||||
use_checkpoint: bool = False,
|
use_checkpoint: bool = False,
|
||||||
share_mod: bool = False,
|
share_mod: bool = False,
|
||||||
initialization: str = 'vanilla',
|
initialization: str = 'vanilla',
|
||||||
@ -638,6 +637,7 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
qk_rms_norm_cross: bool = False,
|
qk_rms_norm_cross: bool = False,
|
||||||
operations=None,
|
operations=None,
|
||||||
device = None,
|
device = None,
|
||||||
|
dtype = torch.float32,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -1004,7 +1004,6 @@ class SparseUnetVaeEncoder(nn.Module):
|
|||||||
self.model_channels = model_channels
|
self.model_channels = model_channels
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
|
||||||
|
|
||||||
self.input_layer = SparseLinear(in_channels, model_channels[0])
|
self.input_layer = SparseLinear(in_channels, model_channels[0])
|
||||||
self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels)
|
self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels)
|
||||||
@ -1247,24 +1246,26 @@ def flexible_dual_grid_to_mesh(
|
|||||||
hashmap_builder=None, # optional callable for building/caching a TorchHashMap
|
hashmap_builder=None, # optional callable for building/caching a TorchHashMap
|
||||||
):
|
):
|
||||||
|
|
||||||
if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset"):
|
device = coords.device
|
||||||
|
if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \
|
||||||
|
or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device:
|
||||||
flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([
|
flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([
|
||||||
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
|
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
|
||||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
|
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
|
||||||
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
|
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
|
||||||
], dtype=torch.int, device=coords.device).unsqueeze(0)
|
], dtype=torch.int, device=device).unsqueeze(0)
|
||||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1"):
|
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device:
|
||||||
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=coords.device, requires_grad=False)
|
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
||||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2"):
|
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device:
|
||||||
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=coords.device, requires_grad=False)
|
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
||||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train"):
|
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train") or flexible_dual_grid_to_mesh.quad_split_train.device != device:
|
||||||
flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=coords.device, requires_grad=False)
|
flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=device, requires_grad=False)
|
||||||
|
|
||||||
# AABB
|
# AABB
|
||||||
if isinstance(aabb, (list, tuple)):
|
if isinstance(aabb, (list, tuple)):
|
||||||
aabb = np.array(aabb)
|
aabb = np.array(aabb)
|
||||||
if isinstance(aabb, np.ndarray):
|
if isinstance(aabb, np.ndarray):
|
||||||
aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device)
|
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
# Voxel size
|
# Voxel size
|
||||||
if voxel_size is not None:
|
if voxel_size is not None:
|
||||||
|
|||||||
@ -276,12 +276,11 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = torch.randn(1, coords.shape[0], in_channels)
|
latent = torch.randn(1, coords.shape[0], in_channels)
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
if "transformer_options" not in model.model_options:
|
model.model_options = model.model_options.copy()
|
||||||
model.model_options = {}
|
if "transformer_options" in model.model_options:
|
||||||
|
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||||
else:
|
else:
|
||||||
model.model_options = model.model_options.copy()
|
model.model_options["transformer_options"] = {}
|
||||||
|
|
||||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
|
||||||
|
|
||||||
model.model_options["transformer_options"]["coords"] = coords
|
model.model_options["transformer_options"]["coords"] = coords
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||||
@ -310,12 +309,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
if "transformer_options" not in model.model_options:
|
model.model_options = model.model_options.copy()
|
||||||
model.model_options = {}
|
if "transformer_options" in model.model_options:
|
||||||
|
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||||
else:
|
else:
|
||||||
model.model_options = model.model_options.copy()
|
model.model_options["transformer_options"] = {}
|
||||||
|
|
||||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
|
||||||
|
|
||||||
model.model_options["transformer_options"]["coords"] = coords
|
model.model_options["transformer_options"]["coords"] = coords
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user