progressing

This commit is contained in:
Yousef Rafat 2026-02-23 01:40:56 +02:00
parent 30db88dd64
commit 2a27c3b417
4 changed files with 33 additions and 10 deletions

View File

@ -226,8 +226,11 @@ class DINOv3ViTLayer(nn.Module):
class DINOv3ViTModel(nn.Module): class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations): def __init__(self, config, dtype, device, operations):
super().__init__() super().__init__()
if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False): use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True)
if dtype == torch.float16 and use_bf16:
dtype = torch.bfloat16 dtype = torch.bfloat16
elif dtype == torch.float16 and not use_bf16:
dtype = torch.float32
num_hidden_layers = config["num_hidden_layers"] num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"] hidden_size = config["hidden_size"]
num_attention_heads = config["num_attention_heads"] num_attention_heads = config["num_attention_heads"]

View File

@ -10,8 +10,8 @@ import comfy.ops
def var_attn_arg(kwargs): def var_attn_arg(kwargs):
cu_seqlens_q = kwargs.get("cu_seqlens_q", None) cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
max_seqlen_q = kwargs.get("max_seqlen_q", None) max_seqlen_q = kwargs.get("max_seqlen_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) or kwargs.get("cu_seqlens_kv", cu_seqlens_q) cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) or kwargs.get("max_kv_seqlen", max_seqlen_q) max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True" assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
@ -183,6 +183,7 @@ def calc_window_partition(
def sparse_scaled_dot_product_attention(*args, **kwargs): def sparse_scaled_dot_product_attention(*args, **kwargs):
q=None
arg_names_dict = { arg_names_dict = {
1: ['qkv'], 1: ['qkv'],
2: ['q', 'kv'], 2: ['q', 'kv'],
@ -250,6 +251,12 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
v = v.reshape(N * L, H, CO) # [T_KV, H, Co] v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
# TODO: change
if q is not None:
heads = q
else:
heads = qkv
heads = heads.shape[2]
if optimized_attention.__name__ == 'attention_xformers': if optimized_attention.__name__ == 'attention_xformers':
if 'xops' not in globals(): if 'xops' not in globals():
import xformers.ops as xops import xformers.ops as xops
@ -279,11 +286,15 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
if num_all_args in [2, 3]: if num_all_args in [2, 3]:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
else:
cu_seqlens_kv = cu_seqlens_q
if num_all_args == 1: if num_all_args == 1:
q, k, v = qkv.unbind(dim=1) q, k, v = qkv.unbind(dim=1)
elif num_all_args == 2: elif num_all_args == 2:
k, v = kv.unbind(dim=1) k, v = kv.unbind(dim=1)
out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen),
skip_reshape=True, skip_output_reshape=True)
if s is not None: if s is not None:
return s.replace(out) return s.replace(out)

View File

@ -232,6 +232,8 @@ class SparseMultiHeadAttention(nn.Module):
else: else:
q = self._linear(self.to_q, x) q = self._linear(self.to_q, x)
q = self._reshape_chs(q, (self.num_heads, -1)) q = self._reshape_chs(q, (self.num_heads, -1))
dtype = next(self.to_kv.parameters()).dtype
context = context.to(dtype)
kv = self._linear(self.to_kv, context) kv = self._linear(self.to_kv, context)
kv = self._fused_pre(kv, num_fused=2) kv = self._fused_pre(kv, num_fused=2)
if self.qk_rms_norm: if self.qk_rms_norm:
@ -760,15 +762,13 @@ class Trellis2(nn.Module):
self.guidance_interval_txt = [0.6, 0.9] self.guidance_interval_txt = [0.6, 0.9]
def forward(self, x, timestep, context, **kwargs): def forward(self, x, timestep, context, **kwargs):
# FIXME: should find a way to distinguish between 512/1024 models
# currently assumes 1024
transformer_options = kwargs.get("transformer_options", {}) transformer_options = kwargs.get("transformer_options", {})
embeds = kwargs.get("embeds") embeds = kwargs.get("embeds")
if embeds is None: if embeds is None:
raise ValueError("Trellis2.forward requires 'embeds' in kwargs") raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
#_, cond = context.chunk(2) # TODO is_1024 = self.img2shape.resolution == 1024
cond = embeds.chunk(2)[0] if is_1024:
context = torch.cat([torch.zeros_like(cond), cond]) context = embeds
coords = transformer_options.get("coords", None) coords = transformer_options.get("coords", None)
mode = transformer_options.get("generation_mode", "structure_generation") mode = transformer_options.get("generation_mode", "structure_generation")
if coords is not None: if coords is not None:

View File

@ -2,6 +2,7 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest import ComfyExtension, IO, Types
import torch.nn.functional as TF import torch.nn.functional as TF
import comfy.model_management import comfy.model_management
from comfy.utils import ProgressBar
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import torch import torch
@ -250,7 +251,7 @@ class Trellis2Conditioning(IO.ComfyNode):
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
embeds = conditioning["cond_1024"] # should add that embeds = conditioning["cond_1024"] # should add that
positive = [[conditioning["cond_512"], {"embeds": embeds}]] positive = [[conditioning["cond_512"], {"embeds": embeds}]]
negative = [[conditioning["neg_cond"], {"embeds": embeds}]] negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]]
return IO.NodeOutput(positive, negative) return IO.NodeOutput(positive, negative)
class EmptyShapeLatentTrellis2(IO.ComfyNode): class EmptyShapeLatentTrellis2(IO.ComfyNode):
@ -512,15 +513,23 @@ class PostProcessMesh(IO.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, mesh, simplify, fill_holes_perimeter): def execute(cls, mesh, simplify, fill_holes_perimeter):
bar = ProgressBar(2)
mesh = copy.deepcopy(mesh) mesh = copy.deepcopy(mesh)
verts, faces = mesh.vertices, mesh.faces verts, faces = mesh.vertices, mesh.faces
if fill_holes_perimeter != 0.0: if fill_holes_perimeter != 0.0:
verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter)
bar.update(1)
else:
bar.update(1)
if simplify != 0: if simplify != 0:
verts, faces = simplify_fn(verts, faces, simplify) verts, faces = simplify_fn(verts, faces, simplify)
bar.update(1)
else:
bar.update(1)
# potentially adding laplacian smoothing
mesh.vertices = verts mesh.vertices = verts
mesh.faces = faces mesh.faces = faces