mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
bunch of fixes
This commit is contained in:
parent
253ee4c02c
commit
c9f5c788a7
@ -2,6 +2,7 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||
|
||||
@ -225,7 +226,7 @@ class DINOv3ViTLayer(nn.Module):
|
||||
class DINOv3ViTModel(nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
if dtype == torch.float16:
|
||||
if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False):
|
||||
dtype = torch.bfloat16
|
||||
num_hidden_layers = config["num_hidden_layers"]
|
||||
hidden_size = config["hidden_size"]
|
||||
|
||||
@ -3,12 +3,56 @@ import math
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from typing import Tuple, Union, List
|
||||
from comfy.ldm.trellis2.vae import VarLenTensor
|
||||
import comfy.ops
|
||||
|
||||
|
||||
# replica of the seedvr2 code
|
||||
def var_attn_arg(kwargs):
    """Extract the variable-length attention arguments from ``kwargs``.

    Recognised keys:
      * ``cu_seqlens_q`` / ``max_seqlen_q`` for the query side.
      * ``cu_seqlens_k`` (or ``cu_seqlens_kv``) and ``max_seqlen_k``
        (or ``max_kv_seqlen``) for the key/value side.

    A key/value entry that is absent or ``None`` falls back to the alternate
    spelling and finally to the corresponding query entry.

    Returns:
        Tuple ``(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)``.

    Raises:
        AssertionError: if ``cu_seqlens_q`` is missing or ``None``.
    """
    def _first_set(primary, alternate, default):
        # Explicit ``is not None`` tests: chaining the lookups with ``or``
        # would call bool() on the value, which raises for a multi-element
        # tensor and wrongly treats a legitimate 0 as "missing".
        for key in (primary, alternate):
            value = kwargs.get(key)
            if value is not None:
                return value
        return default

    cu_seqlens_q = kwargs.get("cu_seqlens_q")
    max_seqlen_q = kwargs.get("max_seqlen_q")
    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"

    cu_seqlens_k = _first_set("cu_seqlens_k", "cu_seqlens_kv", cu_seqlens_q)
    max_seqlen_k = _first_set("max_seqlen_k", "max_kv_seqlen", max_seqlen_q)
    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Variable-length attention over packed sequences via jagged nested tensors.

    The per-sequence boundaries are read from ``kwargs`` through
    ``var_attn_arg`` (``cu_seqlens_q`` is required; the key/value offsets
    default to the query offsets).

    Args:
        q, k, v: packed inputs. Unless ``skip_reshape`` is set they are
            unpacked as 2D ``[total_tokens, embed_dim]`` and viewed as
            ``[tokens, heads, head_dim]``; with ``skip_reshape`` they are
            presumably already in that 3D layout (TODO confirm with callers).
        heads: number of attention heads used to split ``embed_dim``.
        mask: accepted for signature compatibility but ignored; no mask is
            passed to the attention call on this path.
        attn_precision: accepted for signature compatibility; unused.
        skip_reshape: skip the 2D -> 3D head split of the inputs.
        skip_output_reshape: accepted for signature compatibility; unused,
            the packed values are always returned as-is.
        **kwargs: must carry the arguments consumed by ``var_attn_arg``.

    Returns:
        ``.values()`` of the jagged attention output after the head and
        token dimensions are transposed back.
    """
    cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)

    if not skip_reshape:
        # assumes 2D q, k, v [total_tokens, embed_dim]
        total_tokens, embed_dim = q.shape
        head_dim = embed_dim // heads
        q = q.view(total_tokens, heads, head_dim)
        k = k.view(k.shape[0], heads, head_dim)
        v = v.view(v.shape[0], heads, head_dim)

    # Wrap the packed tokens as jagged nested tensors; k and v share offsets.
    q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
    k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
    v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())

    # Swap dims 1 and 2 so the head dimension precedes the (ragged) token dimension.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

    # Undo the transpose and hand back the packed values buffer.
    return out.contiguous().transpose(1, 2).values()
|
||||
|
||||
FLASH_ATTN_3_AVA = True
|
||||
try:
|
||||
import flash_attn_interface as flash_attn_3 # noqa: F401
|
||||
except:
|
||||
FLASH_ATTN_3_AVA = False
|
||||
|
||||
# TODO replace with optimized attention
|
||||
def scaled_dot_product_attention(*args, **kwargs):
|
||||
@ -40,18 +84,10 @@ def scaled_dot_product_attention(*args, **kwargs):
|
||||
k, v = kv.unbind(dim=2)
|
||||
#out = xops.memory_efficient_attention(q, k, v)
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA:
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
if num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_flash': # TODO
|
||||
if 'flash_attn_3' not in globals():
|
||||
import flash_attn_interface as flash_attn_3
|
||||
if num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
out = flash_attn_3.flash_attn_func(q, k, v)
|
||||
elif num_all_args == 3:
|
||||
out = flash_attn_3.flash_attn_func(q, k, v)
|
||||
elif optimized_attention.__name__ == 'attention_pytorch':
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
@ -238,24 +274,16 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
|
||||
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
elif num_all_args == 3:
|
||||
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
elif optimized_attention.__name__ == 'flash_attn_3': # TODO
|
||||
if 'flash_attn_3' not in globals():
|
||||
import flash_attn_interface as flash_attn_3
|
||||
|
||||
elif optimized_attention.__name__ == "attention_pytorch":
|
||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args in [2, 3]:
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
cu_seqlens_kv = cu_seqlens_q.clone()
|
||||
max_q_seqlen = max_kv_seqlen = max(q_seqlen)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=1)
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
max_q_seqlen = max(q_seqlen)
|
||||
max_kv_seqlen = max(kv_seqlen)
|
||||
elif num_all_args == 3:
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
max_q_seqlen = max(q_seqlen)
|
||||
max_kv_seqlen = max(kv_seqlen)
|
||||
out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)
|
||||
out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
|
||||
if s is not None:
|
||||
return s.replace(out)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import math
|
||||
import torch
|
||||
from typing import Dict, Callable
|
||||
import logging
|
||||
|
||||
NO_TRITON = False
|
||||
try:
|
||||
@ -366,6 +367,10 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
||||
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
|
||||
if NO_TRITON: # TODO
|
||||
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
|
||||
if feats.shape[0] == 0:
|
||||
logging.warning("Found feats to be empty!")
|
||||
Co = weight.shape[0]
|
||||
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
|
||||
if len(shape) == 5:
|
||||
N, C, W, H, D = shape
|
||||
else:
|
||||
@ -427,9 +432,11 @@ class Voxel:
|
||||
voxel_size: float,
|
||||
coords: torch.Tensor = None,
|
||||
attrs: torch.Tensor = None,
|
||||
layout: Dict = {},
|
||||
device: torch.device = 'cuda'
|
||||
layout = None,
|
||||
device: torch.device = None
|
||||
):
|
||||
if layout is None:
|
||||
layout = {}
|
||||
self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
|
||||
self.voxel_size = voxel_size
|
||||
self.coords = coords
|
||||
|
||||
@ -630,7 +630,6 @@ class SparseStructureFlowModel(nn.Module):
|
||||
mlp_ratio: float = 4,
|
||||
pe_mode: Literal["ape", "rope"] = "rope",
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
dtype: str = 'float32',
|
||||
use_checkpoint: bool = False,
|
||||
share_mod: bool = False,
|
||||
initialization: str = 'vanilla',
|
||||
@ -638,6 +637,7 @@ class SparseStructureFlowModel(nn.Module):
|
||||
qk_rms_norm_cross: bool = False,
|
||||
operations=None,
|
||||
device = None,
|
||||
dtype = torch.float32,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -1004,7 +1004,6 @@ class SparseUnetVaeEncoder(nn.Module):
|
||||
self.model_channels = model_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
|
||||
self.input_layer = SparseLinear(in_channels, model_channels[0])
|
||||
self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels)
|
||||
@ -1247,24 +1246,26 @@ def flexible_dual_grid_to_mesh(
|
||||
hashmap_builder=None, # optional callable for building/caching a TorchHashMap
|
||||
):
|
||||
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset"):
|
||||
device = coords.device
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \
|
||||
or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device:
|
||||
flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
|
||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
|
||||
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
|
||||
], dtype=torch.int, device=coords.device).unsqueeze(0)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1"):
|
||||
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=coords.device, requires_grad=False)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2"):
|
||||
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=coords.device, requires_grad=False)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train"):
|
||||
flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=coords.device, requires_grad=False)
|
||||
], dtype=torch.int, device=device).unsqueeze(0)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device:
|
||||
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device:
|
||||
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train") or flexible_dual_grid_to_mesh.quad_split_train.device != device:
|
||||
flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=device, requires_grad=False)
|
||||
|
||||
# AABB
|
||||
if isinstance(aabb, (list, tuple)):
|
||||
aabb = np.array(aabb)
|
||||
if isinstance(aabb, np.ndarray):
|
||||
aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device)
|
||||
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||
|
||||
# Voxel size
|
||||
if voxel_size is not None:
|
||||
|
||||
@ -276,12 +276,11 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
in_channels = 32
|
||||
latent = torch.randn(1, coords.shape[0], in_channels)
|
||||
model = model.clone()
|
||||
if "transformer_options" not in model.model_options:
|
||||
model.model_options = {}
|
||||
model.model_options = model.model_options.copy()
|
||||
if "transformer_options" in model.model_options:
|
||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||
else:
|
||||
model.model_options = model.model_options.copy()
|
||||
|
||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||
model.model_options["transformer_options"] = {}
|
||||
|
||||
model.model_options["transformer_options"]["coords"] = coords
|
||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||
@ -310,12 +309,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
in_channels = 32
|
||||
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
||||
model = model.clone()
|
||||
if "transformer_options" not in model.model_options:
|
||||
model.model_options = {}
|
||||
model.model_options = model.model_options.copy()
|
||||
if "transformer_options" in model.model_options:
|
||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||
else:
|
||||
model.model_options = model.model_options.copy()
|
||||
|
||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||
model.model_options["transformer_options"] = {}
|
||||
|
||||
model.model_options["transformer_options"]["coords"] = coords
|
||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user