bunch of fixes

This commit is contained in:
Yousef Rafat 2026-02-22 23:47:49 +02:00
parent 253ee4c02c
commit c9f5c788a7
6 changed files with 86 additions and 51 deletions

View File

@ -2,6 +2,7 @@ import math
import torch
import torch.nn as nn
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
@ -225,7 +226,7 @@ class DINOv3ViTLayer(nn.Module):
class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
if dtype == torch.float16:
if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False):
dtype = torch.bfloat16
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]

View File

@ -3,12 +3,56 @@ import math
from comfy.ldm.modules.attention import optimized_attention
from typing import Tuple, Union, List
from comfy.ldm.trellis2.vae import VarLenTensor
import comfy.ops
# replica of the seedvr2 code
def var_attn_arg(kwargs):
    """Extract variable-length attention arguments from a kwargs dict.

    Accepts both naming conventions seen in callers:
    ``cu_seqlens_k``/``cu_seqlens_kv`` and ``max_seqlen_k``/``max_kv_seqlen``.
    Missing key/value entries fall back to the query-side values.

    Returns:
        (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)

    Raises:
        AssertionError: if ``cu_seqlens_q`` is absent (it is mandatory for
            the variable-length path).
    """
    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
    max_seqlen_q = kwargs.get("max_seqlen_q", None)
    # NOTE: `kwargs.get(...) or kwargs.get(...)` would evaluate the truthiness
    # of the first result; for a multi-element tensor (cu_seqlens are tensors)
    # that raises "Boolean value of Tensor ... is ambiguous". Test for None
    # explicitly instead.
    cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
    if cu_seqlens_k is None:
        cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
    max_seqlen_k = kwargs.get("max_seqlen_k", None)
    if max_seqlen_k is None:
        max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Variable-length attention using PyTorch nested (jagged) tensors.

    q/k/v are expected as packed 2D tensors [total_tokens, embed_dim] (all
    sequences of the batch concatenated) unless skip_reshape is True; the
    per-sequence boundaries come from cu_seqlens_q / cu_seqlens_k in kwargs
    (see var_attn_arg). Returns a packed tensor [total_tokens, heads, head_dim].
    """
    # NOTE(review): hard-coded True — the fixed-length branches below
    # (mask expansion, final reshape) are currently dead code.
    var_length = True
    if var_length:
        cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
        if not skip_reshape:
            # assumes 2D q, k,v [total_tokens, embed_dim]
            total_tokens, embed_dim = q.shape
            head_dim = embed_dim // heads
            q = q.view(total_tokens, heads, head_dim)
            k = k.view(k.shape[0], heads, head_dim)
            v = v.view(v.shape[0], heads, head_dim)
        # b/dim_head are only consumed by the (dead) fixed-length output reshape.
        b = q.size(0)
        dim_head = q.shape[-1]
        # Re-pack [total_tokens, H, D] into jagged nested tensors [B, j, H, D]
        # using the cumulative sequence offsets (k and v share the kv offsets).
        q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
        k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
        v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
        # NOTE(review): any caller-supplied mask is silently discarded here —
        # the jagged layout already encodes sequence boundaries; confirm no
        # caller relies on an additional mask in the var-length path.
        mask = None
    # [B, j, H, D] -> [B, H, j, D] for scaled_dot_product_attention.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    if mask is not None:
        # Broadcast 2D/3D masks up to [B, 1, Lq, Lk] (dead while var_length is True).
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)
    # presumably a project wrapper around torch SDPA that handles nested
    # tensors / backend selection — TODO confirm in comfy.ops.
    out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    if var_length:
        # [B, H, j, D] -> [B, j, H, D]; .values() unpacks the jagged nested
        # tensor back into packed [total_tokens, H, D].
        return out.contiguous().transpose(1, 2).values()
    if not skip_output_reshape:
        out = (
            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
        )
    return out
# Whether the FlashAttention-3 python interface is importable in this
# environment; probed once at module import time.
FLASH_ATTN_3_AVA = True
try:
    import flash_attn_interface as flash_attn_3  # noqa: F401
except ImportError:
    # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit;
    # only a failed import should disable the flag.
    FLASH_ATTN_3_AVA = False
# TODO: replace with optimized attention
def scaled_dot_product_attention(*args, **kwargs):
@ -40,18 +84,10 @@ def scaled_dot_product_attention(*args, **kwargs):
k, v = kv.unbind(dim=2)
#out = xops.memory_efficient_attention(q, k, v)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA:
elif optimized_attention.__name__ == 'attention_flash':
if num_all_args == 2:
k, v = kv.unbind(dim=2)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
elif optimized_attention.__name__ == 'attention_flash': # TODO
if 'flash_attn_3' not in globals():
import flash_attn_interface as flash_attn_3
if num_all_args == 2:
k, v = kv.unbind(dim=2)
out = flash_attn_3.flash_attn_func(q, k, v)
elif num_all_args == 3:
out = flash_attn_3.flash_attn_func(q, k, v)
elif optimized_attention.__name__ == 'attention_pytorch':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
@ -238,24 +274,16 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif num_all_args == 3:
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif optimized_attention.__name__ == 'flash_attn_3': # TODO
if 'flash_attn_3' not in globals():
import flash_attn_interface as flash_attn_3
elif optimized_attention.__name__ == "attention_pytorch":
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
if num_all_args in [2, 3]:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
if num_all_args == 1:
q, k, v = qkv.unbind(dim=1)
cu_seqlens_kv = cu_seqlens_q.clone()
max_q_seqlen = max_kv_seqlen = max(q_seqlen)
elif num_all_args == 2:
k, v = kv.unbind(dim=1)
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
max_q_seqlen = max(q_seqlen)
max_kv_seqlen = max(kv_seqlen)
elif num_all_args == 3:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
max_q_seqlen = max(q_seqlen)
max_kv_seqlen = max(kv_seqlen)
out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)
out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
if s is not None:
return s.replace(out)

View File

@ -3,6 +3,7 @@
import math
import torch
from typing import Dict, Callable
import logging
NO_TRITON = False
try:
@ -366,6 +367,10 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
if NO_TRITON: # TODO
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
if feats.shape[0] == 0:
logging.warning("Found feats to be empty!")
Co = weight.shape[0]
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
if len(shape) == 5:
N, C, W, H, D = shape
else:
@ -427,9 +432,11 @@ class Voxel:
voxel_size: float,
coords: torch.Tensor = None,
attrs: torch.Tensor = None,
layout: Dict = {},
device: torch.device = 'cuda'
layout = None,
device: torch.device = None
):
if layout is None:
layout = {}
self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
self.voxel_size = voxel_size
self.coords = coords

View File

@ -630,7 +630,6 @@ class SparseStructureFlowModel(nn.Module):
mlp_ratio: float = 4,
pe_mode: Literal["ape", "rope"] = "rope",
rope_freq: Tuple[float, float] = (1.0, 10000.0),
dtype: str = 'float32',
use_checkpoint: bool = False,
share_mod: bool = False,
initialization: str = 'vanilla',
@ -638,6 +637,7 @@ class SparseStructureFlowModel(nn.Module):
qk_rms_norm_cross: bool = False,
operations=None,
device = None,
dtype = torch.float32,
**kwargs
):
super().__init__()

View File

@ -1004,7 +1004,6 @@ class SparseUnetVaeEncoder(nn.Module):
self.model_channels = model_channels
self.num_blocks = num_blocks
self.dtype = torch.float16 if use_fp16 else torch.float32
self.dtype = torch.float16 if use_fp16 else torch.float32
self.input_layer = SparseLinear(in_channels, model_channels[0])
self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels)
@ -1247,24 +1246,26 @@ def flexible_dual_grid_to_mesh(
hashmap_builder=None, # optional callable for building/caching a TorchHashMap
):
if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset"):
device = coords.device
if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \
or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device:
flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
], dtype=torch.int, device=coords.device).unsqueeze(0)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1"):
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=coords.device, requires_grad=False)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2"):
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=coords.device, requires_grad=False)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train"):
flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=coords.device, requires_grad=False)
], dtype=torch.int, device=device).unsqueeze(0)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device:
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device:
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train") or flexible_dual_grid_to_mesh.quad_split_train.device != device:
flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=device, requires_grad=False)
# AABB
if isinstance(aabb, (list, tuple)):
aabb = np.array(aabb)
if isinstance(aabb, np.ndarray):
aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device)
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
# Voxel size
if voxel_size is not None:

View File

@ -276,12 +276,11 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
in_channels = 32
latent = torch.randn(1, coords.shape[0], in_channels)
model = model.clone()
if "transformer_options" not in model.model_options:
model.model_options = {}
model.model_options = model.model_options.copy()
if "transformer_options" in model.model_options:
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
else:
model.model_options = model.model_options.copy()
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
model.model_options["transformer_options"] = {}
model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
@ -310,12 +309,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
in_channels = 32
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
model = model.clone()
if "transformer_options" not in model.model_options:
model.model_options = {}
model.model_options = model.model_options.copy()
if "transformer_options" in model.model_options:
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
else:
model.model_options = model.model_options.copy()
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
model.model_options["transformer_options"] = {}
model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"