mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-05 06:01:39 +08:00
Refactor attention, optimize and heavily cleanup model code and cond paths
This commit is contained in:
parent
b2abca0f33
commit
4585a731c1
@ -32,6 +32,11 @@ except ImportError as e:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
try:
|
||||
from sageattention import sageattn_varlen # sageattention >= 2
|
||||
except ImportError:
|
||||
sageattn_varlen = None
|
||||
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
@ -48,6 +53,24 @@ except ImportError:
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
try:
|
||||
from torch.nn.attention.varlen import varlen_attn as _torch_varlen_attn
|
||||
except ImportError:
|
||||
_torch_varlen_attn = None
|
||||
|
||||
|
||||
def _is_varlen(kwargs):
    """Return True when the caller opted into variable-length attention.

    Varlen mode is signalled by a non-None ``cu_seqlens_q`` entry in
    ``kwargs``. A missing key and an explicit ``None`` are treated the same
    way: both mean "not varlen".
    """
    return kwargs.get("cu_seqlens_q") is not None
|
||||
|
||||
|
||||
def _varlen_args(kwargs):
    """Extract the variable-length attention arguments from ``kwargs``.

    ``cu_seqlens_q`` and ``max_seqlen_q`` are required. ``cu_seqlens_kv`` and
    ``max_seqlen_kv`` are optional and fall back to the query-side values when
    they are missing *or* explicitly ``None`` (the same way ``_is_varlen``
    treats a ``None`` ``cu_seqlens_q`` as absent).

    Returns:
        Tuple ``(cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)``
        where both max lengths have been converted with ``int()``.

    Raises:
        KeyError: if ``cu_seqlens_q`` or ``max_seqlen_q`` is not in ``kwargs``.
    """
    cu_seqlens_q = kwargs["cu_seqlens_q"]
    max_seqlen_q = int(kwargs["max_seqlen_q"])

    # Use explicit None checks rather than dict.get defaults so that a caller
    # forwarding `cu_seqlens_kv=None` / `max_seqlen_kv=None` still gets the
    # self-attention fallback instead of a None offset tensor or int(None).
    cu_seqlens_kv = kwargs.get("cu_seqlens_kv")
    if cu_seqlens_kv is None:
        cu_seqlens_kv = cu_seqlens_q

    max_seqlen_kv = kwargs.get("max_seqlen_kv")
    max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else int(max_seqlen_kv)

    return cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
|
||||
|
||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||
def register_attention_function(name: str, func: Callable):
|
||||
# avoid replacing existing functions
|
||||
@ -144,6 +167,8 @@ def wrap_attn(func):
|
||||
|
||||
@wrap_attn
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if _is_varlen(kwargs):
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@ -218,6 +243,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
|
||||
@wrap_attn
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if _is_varlen(kwargs):
|
||||
return attention_pytorch(query, key, value, heads, mask=mask, **kwargs)
|
||||
attn_precision = get_attn_precision(attn_precision, query.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@ -293,6 +320,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
|
||||
@wrap_attn
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if _is_varlen(kwargs):
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@ -424,6 +453,17 @@ except:
|
||||
|
||||
@wrap_attn
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if _is_varlen(kwargs):
|
||||
# q, k, v expected as packed [T_total, H, C]. xformers wants per-item
|
||||
# seqlen lists and a leading batch dim of 1.
|
||||
cu_seqlens_q, cu_seqlens_kv, _max_q, _max_kv = _varlen_args(kwargs)
|
||||
q_seqlen = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()
|
||||
kv_seqlen = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).tolist()
|
||||
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
|
||||
return xformers.ops.memory_efficient_attention(
|
||||
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_bias=attn_bias,
|
||||
)[0]
|
||||
|
||||
b = q.shape[0]
|
||||
dim_head = q.shape[-1]
|
||||
# check to make sure xformers isn't broken
|
||||
@ -493,6 +533,22 @@ else:
|
||||
|
||||
@wrap_attn
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if _is_varlen(kwargs):
|
||||
# q, k, v expected as packed [T_total, H, C]. cu_seqlens_q / cu_seqlens_kv
|
||||
# describe per-item offsets. Native varlen kernel if available, else
|
||||
# nested-tensor SDPA fallback. mask/attn_precision/etc are ignored here.
|
||||
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
|
||||
if _torch_varlen_attn is not None:
|
||||
return _torch_varlen_attn(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
|
||||
q_nj = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||
k_nj = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_kv.long())
|
||||
v_nj = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_kv.long())
|
||||
out = comfy.ops.scaled_dot_product_attention(
|
||||
q_nj.transpose(1, 2), k_nj.transpose(1, 2), v_nj.transpose(1, 2),
|
||||
attn_mask=None, dropout_p=0.0, is_causal=False,
|
||||
)
|
||||
return out.transpose(1, 2).values()
|
||||
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
@ -541,6 +597,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if _is_varlen(kwargs):
|
||||
# q, k, v expected as packed [T_total, H, C].
|
||||
if sageattn_varlen is None:
|
||||
# sageattention v1 has no varlen kernel; fall back to attention_pytorch's varlen path.
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
|
||||
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
|
||||
return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
|
||||
if kwargs.get("low_precision_attention", True) is False:
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
@ -698,6 +761,12 @@ except AttributeError as error:
|
||||
|
||||
@wrap_attn
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if _is_varlen(kwargs):
|
||||
# q, k, v expected as packed [T_total, H, C].
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
|
||||
return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
|
||||
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
|
||||
@ -1,125 +1,107 @@
|
||||
import torch
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from typing import Tuple, Union, List
|
||||
from comfy.ldm.trellis2.vae import VarLenTensor
|
||||
import comfy.ops
|
||||
|
||||
try:
|
||||
from torch.nn.attention.varlen import varlen_attn as _varlen_attn
|
||||
except ImportError:
|
||||
_varlen_attn = None
|
||||
|
||||
|
||||
def var_attn_arg(kwargs):
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
|
||||
max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
|
||||
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||
def _build_cu_seqlens(seqlen, device):
    """Build the cumulative-offset tensor for a batch of sequence lengths.

    Args:
        seqlen: per-item sequence lengths, either a 1-D ``torch.Tensor`` or a
            Python sequence of numbers.
        device: device the resulting offsets are moved to.

    Returns:
        An int32 tensor ``[0, l0, l0 + l1, ...]`` with ``len(seqlen) + 1``
        entries, located on ``device``.
    """
    if isinstance(seqlen, torch.Tensor):
        lengths = seqlen
    else:
        # Explicit integer dtype so an empty list does not become a float
        # tensor and rely on dtype promotion when concatenated below.
        lengths = torch.tensor(seqlen, dtype=torch.int64)

    offsets = torch.cumsum(lengths, dim=0).int()
    # The leading zero is created next to the offsets; the single transfer to
    # the target device happens once, on the concatenated result.
    zero = torch.zeros(1, dtype=torch.int32, device=offsets.device)
    return torch.cat([zero, offsets], dim=0).to(device)
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
|
||||
if not skip_reshape:
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
q = q.view(total_tokens, heads, head_dim)
|
||||
k = k.view(k.shape[0], heads, head_dim)
|
||||
v = v.view(v.shape[0], heads, head_dim)
|
||||
|
||||
if _varlen_attn is not None:
|
||||
return _varlen_attn(
|
||||
q, k, v,
|
||||
cu_seqlens_q, cu_seqlens_k,
|
||||
int(max_seqlen_q), int(max_seqlen_k),
|
||||
)
|
||||
def _layout_to_feats(t):
    """Normalize an attention operand to packed features plus item lengths.

    Returns a tuple ``(template, seqlen, feats)``:

    * ``VarLenTensor`` input: ``template`` is ``t`` itself, ``seqlen`` holds
      ``stop - start`` of each of the first ``t.shape[0]`` layout slices, and
      ``feats`` is ``t.feats`` (documented by the original as ``[T, H, C]``).
    * any other input: treated as a dense tensor whose first two dimensions
      are ``(N, L)``; ``template`` is ``None``, ``seqlen`` is ``[L] * N`` and
      ``feats`` flattens every leading dimension while keeping the last two.
    """
    if not isinstance(t, VarLenTensor):
        n_items, length = t.shape[:2]
        return None, [length] * n_items, t.reshape(-1, *t.shape[-2:])

    # Index the layout by position (rather than iterating it) to keep the
    # original access pattern; each slice is read once.
    spans = (t.layout[i] for i in range(t.shape[0]))
    seqlen = [span.stop - span.start for span in spans]
    return t, seqlen, t.feats
|
||||
|
||||
# Fallback: nested-tensor SDPA (PyTorch < the version that introduced varlen_attn)
|
||||
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
return out.transpose(1, 2).values()
|
||||
|
||||
def scaled_dot_product_attention(*args, **kwargs):
|
||||
num_all_args = len(args) + len(kwargs)
|
||||
|
||||
q = None
|
||||
if num_all_args == 1:
|
||||
qkv = args[0] if len(args) > 0 else kwargs.get('qkv')
|
||||
elif num_all_args == 2:
|
||||
q = args[0] if len(args) > 0 else kwargs.get('q')
|
||||
kv = args[1] if len(args) > 1 else kwargs.get('kv')
|
||||
elif num_all_args == 3:
|
||||
q = args[0] if len(args) > 0 else kwargs.get('q')
|
||||
k = args[1] if len(args) > 1 else kwargs.get('k')
|
||||
v = args[2] if len(args) > 2 else kwargs.get('v')
|
||||
|
||||
if q is not None:
|
||||
heads = q.shape[2]
|
||||
else:
|
||||
heads = qkv.shape[3]
|
||||
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
|
||||
def dense_attention(q, k, v, **kwargs):
    """Run the configured attention backend on dense ``[B, L, H, C]`` inputs.

    The inputs are permuted to the ``[B, H, L, C]`` layout, passed to
    ``optimized_attention`` with both ``skip_reshape`` and
    ``skip_output_reshape`` enabled, and the result is permuted back with the
    same ``(0, 2, 1, 3)`` axis order. Extra keyword arguments are forwarded to
    the backend unchanged.
    """
    # Read the head count before the permute moves the head axis to dim 1.
    heads = q.shape[2]
    q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))

    out = optimized_attention(
        q, k, v, heads,
        skip_output_reshape=True, skip_reshape=True,
        **kwargs,
    )
    return out.permute(0, 2, 1, 3)
|
||||
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
return out
|
||||
def sparse_attention(q, k, v, **kwargs):
|
||||
"""
|
||||
Varlen attention for SparseTensor inputs. Each of q, k, v may be a VarLenTensor
|
||||
(sparse) or dense [B, L, H, C]. Output type matches q. Backend dispatch lives
|
||||
in comfy.ldm.modules.attention.optimized_attention; we just build cu_seqlens
|
||||
from the layouts.
|
||||
"""
|
||||
s, q_seqlen, q_feats = _layout_to_feats(q)
|
||||
_, kv_seqlen, k_feats = _layout_to_feats(k)
|
||||
_, _, v_feats = _layout_to_feats(v)
|
||||
heads = q_feats.shape[1]
|
||||
|
||||
def sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv,
|
||||
window_size: int,
|
||||
shift_window: Tuple[int, int, int] = (0, 0, 0)
|
||||
):
|
||||
device = q_feats.device
|
||||
cu_seqlens_q = _build_cu_seqlens(q_seqlen, device)
|
||||
cu_seqlens_kv = _build_cu_seqlens(kv_seqlen, device)
|
||||
|
||||
serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
|
||||
serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
|
||||
if serialization_spatial_cache is None:
|
||||
fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
|
||||
qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
|
||||
else:
|
||||
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
|
||||
out = optimized_attention(
|
||||
q_feats, k_feats, v_feats, heads,
|
||||
cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv,
|
||||
max_seqlen_q=max(q_seqlen), max_seqlen_kv=max(kv_seqlen),
|
||||
skip_reshape=True, skip_output_reshape=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if s is not None:
|
||||
return s.replace(out)
|
||||
N, L = q.shape[:2]
|
||||
return out.reshape(N, L, heads, -1)
|
||||
|
||||
|
||||
def sparse_windowed_self_attention(qkv, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0)):
|
||||
"""Windowed sparse self-attention. Partitions voxels into windows via spatial
|
||||
sort, then runs varlen attention with one sequence per non-empty window."""
|
||||
cache_name = f'windowed_attention_{window_size}_{shift_window}'
|
||||
cache = qkv.get_spatial_cache(cache_name)
|
||||
if cache is None:
|
||||
cache = calc_window_partition(qkv, window_size, shift_window)
|
||||
qkv.register_spatial_cache(cache_name, cache)
|
||||
fwd_indices, bwd_indices, seq_lens = cache
|
||||
|
||||
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
|
||||
heads = qkv_feats.shape[2]
|
||||
q, k, v = qkv_feats.unbind(dim=1) # each [M, H, C]
|
||||
heads = q.shape[1]
|
||||
device = q.device
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
q, k, v = qkv_feats.unbind(dim=1)
|
||||
q = q.unsqueeze(0) # [1, M, H, C]
|
||||
k = k.unsqueeze(0) # [1, M, H, C]
|
||||
v = v.unsqueeze(0) # [1, M, H, C]
|
||||
#out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
if 'flash_attn' not in globals():
|
||||
import flash_attn
|
||||
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
|
||||
else:
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
|
||||
out = out[bwd_indices] # [T, H, C]
|
||||
cu_seqlens = _build_cu_seqlens(seq_lens, device)
|
||||
max_seqlen = int(seq_lens.max())
|
||||
|
||||
out = optimized_attention(
|
||||
q, k, v, heads,
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
|
||||
max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen,
|
||||
skip_reshape=True, skip_output_reshape=True,
|
||||
)
|
||||
out = out[bwd_indices]
|
||||
return qkv.replace(out)
|
||||
|
||||
|
||||
def calc_window_partition(
|
||||
tensor,
|
||||
window_size: Union[int, Tuple[int, ...]],
|
||||
shift_window: Union[int, Tuple[int, ...]] = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
|
||||
|
||||
):
|
||||
"""Returns (fwd_indices, bwd_indices, seq_lens) for window partitioning."""
|
||||
DIM = tensor.coords.shape[1] - 1
|
||||
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
|
||||
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
|
||||
@ -136,139 +118,6 @@ def calc_window_partition(
|
||||
bwd_indices = torch.empty_like(fwd_indices)
|
||||
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
|
||||
seq_lens = torch.bincount(shifted_indices)
|
||||
mask = seq_lens != 0
|
||||
seq_lens = seq_lens[mask]
|
||||
seq_lens = seq_lens[seq_lens != 0]
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if 'xops' not in globals():
|
||||
import xformers.ops as xops
|
||||
attn_func_args = {
|
||||
'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
|
||||
}
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
attn_func_args = {
|
||||
'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
|
||||
'max_seqlen': torch.max(seq_lens)
|
||||
}
|
||||
|
||||
return fwd_indices, bwd_indices, seq_lens, attn_func_args
|
||||
|
||||
|
||||
def sparse_scaled_dot_product_attention(*args, **kwargs):
|
||||
q=None
|
||||
arg_names_dict = {
|
||||
1: ['qkv'],
|
||||
2: ['q', 'kv'],
|
||||
3: ['q', 'k', 'v']
|
||||
}
|
||||
num_all_args = len(args) + len(kwargs)
|
||||
for key in arg_names_dict[num_all_args][len(args):]:
|
||||
assert key in kwargs, f"Missing argument {key}"
|
||||
|
||||
if num_all_args == 1:
|
||||
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
||||
device = qkv.device
|
||||
|
||||
s = qkv
|
||||
q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
|
||||
kv_seqlen = q_seqlen
|
||||
qkv = qkv.feats # [T, 3, H, C]
|
||||
|
||||
elif num_all_args == 2:
|
||||
q = args[0] if len(args) > 0 else kwargs['q']
|
||||
kv = args[1] if len(args) > 1 else kwargs['kv']
|
||||
device = q.device
|
||||
|
||||
if isinstance(q, VarLenTensor):
|
||||
s = q
|
||||
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
|
||||
q = q.feats # [T_Q, H, C]
|
||||
else:
|
||||
s = None
|
||||
N, L, H, C = q.shape
|
||||
q_seqlen = [L] * N
|
||||
q = q.reshape(N * L, H, C) # [T_Q, H, C]
|
||||
|
||||
if isinstance(kv, VarLenTensor):
|
||||
kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
|
||||
kv = kv.feats # [T_KV, 2, H, C]
|
||||
else:
|
||||
N, L, _, H, C = kv.shape
|
||||
kv_seqlen = [L] * N
|
||||
kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
|
||||
|
||||
elif num_all_args == 3:
|
||||
q = args[0] if len(args) > 0 else kwargs['q']
|
||||
k = args[1] if len(args) > 1 else kwargs['k']
|
||||
v = args[2] if len(args) > 2 else kwargs['v']
|
||||
device = q.device
|
||||
|
||||
if isinstance(q, VarLenTensor):
|
||||
s = q
|
||||
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
|
||||
q = q.feats # [T_Q, H, Ci]
|
||||
else:
|
||||
s = None
|
||||
N, L, H, CI = q.shape
|
||||
q_seqlen = [L] * N
|
||||
q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
|
||||
|
||||
if isinstance(k, VarLenTensor):
|
||||
kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
|
||||
k = k.feats # [T_KV, H, Ci]
|
||||
v = v.feats # [T_KV, H, Co]
|
||||
else:
|
||||
N, L, H, CI, CO = *k.shape, v.shape[-1]
|
||||
kv_seqlen = [L] * N
|
||||
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
|
||||
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
|
||||
|
||||
# TODO: change
|
||||
if q is not None:
|
||||
heads = q
|
||||
else:
|
||||
heads = qkv
|
||||
heads = heads.shape[2]
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if 'xops' not in globals():
|
||||
import xformers.ops as xops
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=1)
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
|
||||
out = xops.memory_efficient_attention(q, k, v, mask)[0]
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
if 'flash_attn' not in globals():
|
||||
import flash_attn
|
||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args in [2, 3]:
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args == 1:
|
||||
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
|
||||
elif num_all_args == 2:
|
||||
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
elif num_all_args == 3:
|
||||
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
|
||||
elif optimized_attention.__name__ == "attention_pytorch":
|
||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args in [2, 3]:
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
else:
|
||||
cu_seqlens_kv = cu_seqlens_q
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=1)
|
||||
out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen),
|
||||
skip_reshape=True, skip_output_reshape=True)
|
||||
|
||||
if s is not None:
|
||||
return s.replace(out)
|
||||
else:
|
||||
return out.reshape(N, L, H, -1)
|
||||
return fwd_indices, bwd_indices, seq_lens
|
||||
|
||||
@ -4,11 +4,12 @@ import torch.nn as nn
|
||||
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
|
||||
from typing import Optional, Tuple, Literal, Union, List
|
||||
from comfy.ldm.trellis2.attention import (
|
||||
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
||||
sparse_windowed_self_attention, sparse_attention, dense_attention
|
||||
)
|
||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||
|
||||
|
||||
class SparseGELU(nn.GELU):
|
||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||
return input.replace(super().forward(input.feats))
|
||||
@ -31,22 +32,16 @@ class LayerNorm32(nn.LayerNorm):
|
||||
x = x.to(dtype=torch.float32)
|
||||
o = super().forward(x)
|
||||
return o.to(dtype=x_dtype)
|
||||
|
||||
|
||||
|
||||
class SparseMultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, heads: int, device, dtype):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
x_type = x.dtype
|
||||
x = x.float()
|
||||
if isinstance(x, VarLenTensor):
|
||||
x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale)
|
||||
else:
|
||||
x = F.normalize(x, dim=-1) * self.gamma * self.scale
|
||||
return x.to(x_type)
|
||||
return x.replace(F.rms_norm(x.feats, (x.feats.shape[-1],)) * self.gamma)
|
||||
return F.rms_norm(x, (x.shape[-1],)) * self.gamma
|
||||
|
||||
class SparseRotaryPositionEmbedder(nn.Module):
|
||||
def __init__(
|
||||
@ -84,12 +79,6 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
self.freqs = self.freqs.to(indices.device)
|
||||
phases = torch.outer(indices, self.freqs)
|
||||
phases = torch.polar(torch.ones_like(phases), phases)
|
||||
return phases
|
||||
|
||||
def forward(self, q, k=None):
|
||||
cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
|
||||
freqs_cis = q.get_spatial_cache(cache_name)
|
||||
@ -110,27 +99,10 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
||||
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
|
||||
return q.replace(q_feats), k.replace(k_feats)
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
|
||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
x_rotated = x_complex * phases.unsqueeze(-2)
|
||||
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
|
||||
return x_embed
|
||||
|
||||
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
||||
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
||||
if torch.is_complex(phases):
|
||||
phases = phases.to(torch.complex64)
|
||||
else:
|
||||
phases = phases.to(torch.float32)
|
||||
if phases.shape[-1] < self.head_dim // 2:
|
||||
padn = self.head_dim // 2 - phases.shape[-1]
|
||||
phases = torch.cat([phases, torch.polar(
|
||||
torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32),
|
||||
torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32)
|
||||
)], dim=-1)
|
||||
return phases
|
||||
def forward(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
return self._get_freqs_cis(coords) # [L, head_dim/2, 2, 2]
|
||||
|
||||
class SparseMultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
@ -198,50 +170,53 @@ class SparseMultiHeadAttention(nn.Module):
|
||||
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
|
||||
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
|
||||
|
||||
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
|
||||
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None,
|
||||
transformer_options=None) -> SparseTensor:
|
||||
if self._type == "self":
|
||||
dtype = next(self.to_qkv.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
qkv = self._linear(self.to_qkv, x)
|
||||
qkv = self._fused_pre(qkv, num_fused=3)
|
||||
if self.qk_rms_norm or self.use_rope:
|
||||
if self.attn_mode == "full":
|
||||
q, k, v = qkv.unbind(dim=-3)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
if self.use_rope:
|
||||
q, k = self.rope(q, k)
|
||||
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
|
||||
if self.attn_mode == "full":
|
||||
h = sparse_scaled_dot_product_attention(qkv)
|
||||
elif self.attn_mode == "windowed":
|
||||
h = sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv, self.window_size, shift_window=self.shift_window
|
||||
)
|
||||
elif self.attn_mode == "double_windowed":
|
||||
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
|
||||
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
|
||||
h0 = sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv0, self.window_size, shift_window=(0, 0, 0)
|
||||
)
|
||||
h1 = sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
|
||||
)
|
||||
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
|
||||
h = sparse_attention(q, k, v, transformer_options=transformer_options)
|
||||
else:
|
||||
# Windowed paths take packed qkv; preserve any per-head norm/rope.
|
||||
if self.qk_rms_norm or self.use_rope:
|
||||
q, k, v = qkv.unbind(dim=-3)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
if self.use_rope:
|
||||
q, k = self.rope(q, k)
|
||||
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
|
||||
if self.attn_mode == "windowed":
|
||||
h = sparse_windowed_self_attention(
|
||||
qkv, self.window_size, shift_window=self.shift_window
|
||||
)
|
||||
elif self.attn_mode == "double_windowed":
|
||||
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
|
||||
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
|
||||
h0 = sparse_windowed_self_attention(
|
||||
qkv0, self.window_size, shift_window=(0, 0, 0)
|
||||
)
|
||||
h1 = sparse_windowed_self_attention(
|
||||
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
|
||||
)
|
||||
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
|
||||
else:
|
||||
q = self._linear(self.to_q, x)
|
||||
q = self._reshape_chs(q, (self.num_heads, -1))
|
||||
dtype = next(self.to_kv.parameters()).dtype
|
||||
context = context.to(dtype)
|
||||
kv = self._linear(self.to_kv, context)
|
||||
kv = self._fused_pre(kv, num_fused=2)
|
||||
k, v = kv.unbind(dim=-3)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k, v = kv.unbind(dim=-3)
|
||||
k = self.k_rms_norm(k)
|
||||
h = sparse_scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
h = sparse_scaled_dot_product_attention(q, kv)
|
||||
h = sparse_attention(q, k, v, transformer_options=transformer_options)
|
||||
h = self._reshape_chs(h, (-1,))
|
||||
h = self._linear(self.to_out, h)
|
||||
return h
|
||||
@ -265,9 +240,9 @@ class ProjectAttentionSparse(nn.Module):
|
||||
self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: SparseTensor, context) -> SparseTensor:
|
||||
def forward(self, x: SparseTensor, context, transformer_options=None) -> SparseTensor:
|
||||
global_ctx, proj_in = _split_proj_context(context)
|
||||
global_out = self.cross_attn_block(x, global_ctx)
|
||||
global_out = self.cross_attn_block(x, global_ctx, transformer_options=transformer_options)
|
||||
if isinstance(proj_in, tuple):
|
||||
proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1)
|
||||
proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype))
|
||||
@ -282,9 +257,9 @@ class ProjectAttentionDense(nn.Module):
|
||||
self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, context) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, context, transformer_options=None) -> torch.Tensor:
|
||||
global_ctx, proj_in = _split_proj_context(context)
|
||||
global_out = self.cross_attn_block(x, global_ctx)
|
||||
global_out = self.cross_attn_block(x, global_ctx, transformer_options=transformer_options)
|
||||
if isinstance(proj_in, tuple):
|
||||
proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1)
|
||||
proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype))
|
||||
@ -367,32 +342,36 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor:
|
||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor:
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = x.replace(self.norm1(x.feats))
|
||||
h = h * (1 + scale_msa) + shift_msa
|
||||
h = self.self_attn(h)
|
||||
h = h * gate_msa
|
||||
x = x + h
|
||||
# Fuse the (mul + add) and (mul + residual) pairs into addcmul so the
|
||||
# mod/shift broadcasts hit one kernel each instead of two.
|
||||
b_map = x.batch_boardcast_map
|
||||
|
||||
h_feats = self.norm1(x.feats)
|
||||
h_feats = torch.addcmul(shift_msa[b_map], h_feats, (1 + scale_msa)[b_map])
|
||||
h = self.self_attn(x.replace(h_feats), transformer_options=transformer_options)
|
||||
x = x.replace(torch.addcmul(x.feats, h.feats, gate_msa[b_map]))
|
||||
|
||||
h = x.replace(self.norm2(x.feats))
|
||||
if self.image_attn_mode == "global":
|
||||
global_ctx, _ = _split_proj_context(context)
|
||||
h = self.cross_attn(h, global_ctx)
|
||||
h = self.cross_attn(h, global_ctx, transformer_options=transformer_options)
|
||||
else:
|
||||
h = self.cross_attn(h, context)
|
||||
x = x + h
|
||||
h = x.replace(self.norm3(x.feats))
|
||||
h = h * (1 + scale_mlp) + shift_mlp
|
||||
h = self.mlp(h)
|
||||
h = h * gate_mlp
|
||||
h = self.cross_attn(h, context, transformer_options=transformer_options)
|
||||
x = x + h
|
||||
|
||||
h_feats = self.norm3(x.feats)
|
||||
h_feats = torch.addcmul(shift_mlp[b_map], h_feats, (1 + scale_mlp)[b_map])
|
||||
h = self.mlp(x.replace(h_feats))
|
||||
x = x.replace(torch.addcmul(x.feats, h.feats, gate_mlp[b_map]))
|
||||
return x
|
||||
|
||||
def forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor:
|
||||
return self._forward(x, mod, context)
|
||||
def forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor:
|
||||
return self._forward(x, mod, context, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class SLatFlowModel(nn.Module):
|
||||
@ -480,24 +459,21 @@ class SLatFlowModel(nn.Module):
|
||||
t: torch.Tensor,
|
||||
cond: Union[torch.Tensor, List[torch.Tensor]],
|
||||
concat_cond: Optional[SparseTensor] = None,
|
||||
**kwargs
|
||||
transformer_options=None,
|
||||
**kwargs,
|
||||
) -> SparseTensor:
|
||||
if concat_cond is not None:
|
||||
x = sparse_cat([x, concat_cond], dim=-1)
|
||||
if isinstance(cond, list):
|
||||
cond = VarLenTensor.from_tensor_list(cond)
|
||||
|
||||
dtype = next(self.input_layer.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
h = self.input_layer(x)
|
||||
t = t.to(dtype)
|
||||
t_embedder = self.t_embedder.to(dtype)
|
||||
t_emb = t_embedder(t, out_dtype = t.dtype)
|
||||
t_emb = self.t_embedder(t, out_dtype=t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
|
||||
for block in self.blocks:
|
||||
h = block(h, t_emb, cond)
|
||||
h = block(h, t_emb, cond, transformer_options=transformer_options)
|
||||
|
||||
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
||||
h = self.out_layer(h)
|
||||
@ -566,40 +542,34 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None,
|
||||
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
|
||||
B, L, C = x.shape
|
||||
if self._type == "self":
|
||||
x = x.to(next(self.to_qkv.parameters()).dtype)
|
||||
qkv = self.to_qkv(x)
|
||||
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
|
||||
|
||||
if self.attn_mode == "full":
|
||||
if self.qk_rms_norm or self.use_rope:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
if self.use_rope:
|
||||
assert phases is not None, "Phases must be provided for RoPE"
|
||||
q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
|
||||
k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
|
||||
h = scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
h = scaled_dot_product_attention(qkv)
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
if self.use_rope:
|
||||
assert phases is not None, "Phases must be provided for RoPE"
|
||||
# phases is [L, head_dim/2, 2, 2]; broadcast to [1, L, 1, ...]
|
||||
# to align with q/k of shape [B, L, H, head_dim].
|
||||
f_cis = phases.unsqueeze(0).unsqueeze(2)
|
||||
q, k = apply_rope(q, k, f_cis)
|
||||
h = dense_attention(q, k, v, transformer_options=transformer_options)
|
||||
else:
|
||||
Lkv = context.shape[1]
|
||||
q = self.to_q(x)
|
||||
context = context.to(next(self.to_kv.parameters()).dtype)
|
||||
kv = self.to_kv(context)
|
||||
q = q.reshape(B, L, self.num_heads, -1)
|
||||
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
||||
k, v = kv.unbind(dim=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k, v = kv.unbind(dim=2)
|
||||
k = self.k_rms_norm(k)
|
||||
h = scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
h = scaled_dot_product_attention(q, kv)
|
||||
h = dense_attention(q, k, v, transformer_options=transformer_options)
|
||||
h = h.reshape(B, L, -1)
|
||||
h = self.to_out(h)
|
||||
return h
|
||||
@ -677,32 +647,39 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context,
|
||||
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = self.norm1(x)
|
||||
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
||||
h = self.self_attn(h, phases=phases)
|
||||
h = h * gate_msa.unsqueeze(1)
|
||||
x = x + h
|
||||
shift_msa = shift_msa.unsqueeze(1)
|
||||
scale_msa = scale_msa.unsqueeze(1)
|
||||
gate_msa = gate_msa.unsqueeze(1)
|
||||
shift_mlp = shift_mlp.unsqueeze(1)
|
||||
scale_mlp = scale_mlp.unsqueeze(1)
|
||||
gate_mlp = gate_mlp.unsqueeze(1)
|
||||
|
||||
h = torch.addcmul(shift_msa, self.norm1(x), 1 + scale_msa)
|
||||
h = self.self_attn(h, phases=phases, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, h, gate_msa)
|
||||
|
||||
h = self.norm2(x)
|
||||
if self.image_attn_mode == "global":
|
||||
global_ctx, _ = _split_proj_context(context)
|
||||
h = self.cross_attn(h, global_ctx)
|
||||
h = self.cross_attn(h, global_ctx, transformer_options=transformer_options)
|
||||
else:
|
||||
h = self.cross_attn(h, context)
|
||||
h = self.cross_attn(h, context, transformer_options=transformer_options)
|
||||
x = x + h
|
||||
h = self.norm3(x)
|
||||
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||
|
||||
h = torch.addcmul(shift_mlp, self.norm3(x), 1 + scale_mlp)
|
||||
h = self.mlp(h)
|
||||
h = h * gate_mlp.unsqueeze(1)
|
||||
x = x + h
|
||||
x = torch.addcmul(x, h, gate_mlp)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return self._forward(x, mod, context, phases)
|
||||
def forward(self, x: torch.Tensor, mod: torch.Tensor, context,
|
||||
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
|
||||
return self._forward(x, mod, context, phases, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class SparseStructureFlowModel(nn.Module):
|
||||
@ -792,18 +769,18 @@ class SparseStructureFlowModel(nn.Module):
|
||||
|
||||
self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor,
|
||||
transformer_options=None) -> torch.Tensor:
|
||||
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
|
||||
|
||||
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
|
||||
|
||||
h = h.to(next(self.input_layer.parameters()).dtype)
|
||||
h = self.input_layer(h)
|
||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||
t_emb = self.t_embedder(t, out_dtype=t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
for block in self.blocks:
|
||||
h = block(h, t_emb, cond, self.rope_phases)
|
||||
h = block(h, t_emb, cond, self.rope_phases, transformer_options=transformer_options)
|
||||
h = F.layer_norm(h, h.shape[-1:])
|
||||
h = self.out_layer(h)
|
||||
|
||||
@ -820,7 +797,7 @@ def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
||||
|
||||
|
||||
# Pixal3D ProjGrid math — port of upstream's ProjGrid + project_points_to_image_batch.
|
||||
# World frame uses world Y as depth (Blender convention), camera looks along -Z local;
|
||||
# World frame uses world Y as depth, camera looks along -Z local;
|
||||
# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2)
|
||||
# with sensor_width = 32mm.
|
||||
|
||||
@ -927,7 +904,7 @@ def _back_project_to_tokens(
|
||||
if not mask.any():
|
||||
continue
|
||||
p = coords_world[mask].unsqueeze(0)
|
||||
uv, depth, valid = _project_points_to_image(
|
||||
uv, _, _ = _project_points_to_image(
|
||||
p, transform_matrix[b:b+1], camera_angle_x[b:b+1], image_resolution)
|
||||
uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0
|
||||
# padding_mode='border' is load-bearing: masking out-of-frame voxels confuses
|
||||
@ -937,7 +914,7 @@ def _back_project_to_tokens(
|
||||
out[mask] = sampled
|
||||
return out
|
||||
else:
|
||||
uv, depth, valid = _project_points_to_image(
|
||||
uv, _, _ = _project_points_to_image(
|
||||
coords_world, transform_matrix, camera_angle_x, image_resolution)
|
||||
uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0
|
||||
sampled = _sample_features(feature_map, uv_ndc)
|
||||
@ -945,21 +922,6 @@ def _back_project_to_tokens(
|
||||
return out
|
||||
|
||||
|
||||
def _pack_per_voxel_scalar(proj_pack: Optional[dict], key: str, eval_batch: int, device) -> torch.Tensor:
|
||||
if proj_pack is None or key not in proj_pack:
|
||||
return torch.ones((eval_batch,), device=device, dtype=torch.float32)
|
||||
t = proj_pack[key].to(device=device, dtype=torch.float32)
|
||||
if t.ndim == 0:
|
||||
return t.expand(eval_batch).clone()
|
||||
return _expand_pack(t, eval_batch)
|
||||
|
||||
|
||||
def _expand_pack(t: torch.Tensor, eval_batch: int) -> torch.Tensor:
|
||||
if eval_batch == t.shape[0]:
|
||||
return t
|
||||
if eval_batch % t.shape[0] != 0:
|
||||
raise ValueError(f"eval batch {eval_batch} is not a multiple of pack batch {t.shape[0]}")
|
||||
return t.repeat((eval_batch // t.shape[0],) + (1,) * (t.ndim - 1))
|
||||
|
||||
|
||||
def _select_stage_entry(proj_pack: dict, stage: Optional[str]):
|
||||
@ -973,53 +935,96 @@ def _select_stage_entry(proj_pack: dict, stage: Optional[str]):
|
||||
raise ValueError(f"proj_feat_pack has no usable feature_map (stage={stage!r})")
|
||||
|
||||
|
||||
def _build_proj_cond(global_cond: torch.Tensor, image_attn_mode: str, proj_pack: Optional[dict],
|
||||
coords_world: torch.Tensor, batch_ids: Optional[torch.Tensor] = None,
|
||||
eval_batch: Optional[int] = None,
|
||||
proj_in_channels: Optional[int] = None,
|
||||
stage: Optional[str] = None,
|
||||
cond_or_uncond: Optional[list] = None):
|
||||
if image_attn_mode == "global":
|
||||
return global_cond
|
||||
if proj_pack is None:
|
||||
raise ValueError(f"image_attn_mode={image_attn_mode!r} but proj_feat_pack is missing")
|
||||
device = coords_world.device
|
||||
def compute_stage_proj_feats(
|
||||
proj_pack: dict,
|
||||
stage: str,
|
||||
coords: Optional[torch.Tensor] = None,
|
||||
coord_resolution: Optional[int] = None,
|
||||
dense_grid_resolution: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
device=None,
|
||||
) -> torch.Tensor:
|
||||
"""Back-project a Pixal3D stage's feature maps onto its target voxel/grid coords.
|
||||
|
||||
For sparse (shape / texture) stages: pass ``coords`` (with ``coord_resolution``).
|
||||
Returns ``[N_voxels, C]`` per-voxel features with channel count =
|
||||
LR channels + optional HR channels.
|
||||
|
||||
For the dense SS stage: pass ``dense_grid_resolution`` (16) + ``batch_size``.
|
||||
Returns ``[B, R^3, C]`` features for the dense grid.
|
||||
|
||||
"""
|
||||
if device is None:
|
||||
device = coords.device if coords is not None else proj_pack["mesh_scale"].device
|
||||
mesh_scale = proj_pack["mesh_scale"].to(device)
|
||||
T = proj_pack["transform_matrix"].to(device)
|
||||
cam_angle = proj_pack["camera_angle_x"].to(device)
|
||||
feat_map_lr, feat_map_hr, image_resolution = _select_stage_entry(proj_pack, stage)
|
||||
feat_map_lr = feat_map_lr.to(device)
|
||||
if feat_map_hr is not None:
|
||||
feat_map_hr = feat_map_hr.to(device)
|
||||
if eval_batch is not None:
|
||||
T = _expand_pack(T, eval_batch)
|
||||
cam_angle = _expand_pack(cam_angle, eval_batch) if cam_angle.ndim >= 1 else cam_angle
|
||||
feat_map_lr = _expand_pack(feat_map_lr, eval_batch)
|
||||
if feat_map_hr is not None:
|
||||
feat_map_hr = _expand_pack(feat_map_hr, eval_batch)
|
||||
# Channel-count check against the trained proj_linear input. If HR is present, the
|
||||
# block expects (LR_channels + HR_channels) since we concat the sampled features.
|
||||
expected_channels = feat_map_lr.shape[1] + (feat_map_hr.shape[1] if feat_map_hr is not None else 0)
|
||||
if proj_in_channels is not None and expected_channels != proj_in_channels:
|
||||
|
||||
if coords is not None:
|
||||
if coord_resolution is None:
|
||||
raise ValueError("compute_stage_proj_feats: coord_resolution required when coords is given")
|
||||
coords_world, batch_ids = _coords_to_proj_world(coords, coord_resolution, mesh_scale)
|
||||
else:
|
||||
if dense_grid_resolution is None or batch_size is None:
|
||||
raise ValueError("compute_stage_proj_feats: dense_grid_resolution + batch_size required for dense path")
|
||||
coords_world = _dense_grid_proj_world(dense_grid_resolution, mesh_scale, batch_size,
|
||||
device=device, dtype=torch.float32)
|
||||
batch_ids = None
|
||||
|
||||
proj_lr = _back_project_to_tokens(coords_world, feat_map_lr, T, cam_angle,
|
||||
image_resolution=image_resolution, batch_ids=batch_ids)
|
||||
if feat_map_hr is not None:
|
||||
proj_hr = _back_project_to_tokens(coords_world, feat_map_hr, T, cam_angle,
|
||||
image_resolution=image_resolution, batch_ids=batch_ids)
|
||||
return torch.cat([proj_lr, proj_hr], dim=-1)
|
||||
return proj_lr
|
||||
|
||||
|
||||
def _shape_proj_cond(global_cond: torch.Tensor, image_attn_mode: str,
|
||||
proj_feats: Optional[torch.Tensor],
|
||||
batch_ids: Optional[torch.Tensor] = None,
|
||||
eval_batch: Optional[int] = None,
|
||||
logical_batch: Optional[int] = None,
|
||||
proj_in_channels: Optional[int] = None,
|
||||
stage: Optional[str] = None,
|
||||
cond_or_uncond: Optional[list] = None,
|
||||
has_hr: bool = False):
|
||||
"""Take pre-computed per-token proj features (from compute_stage_proj_feats),
|
||||
apply CFG-batch duplication + uncond-slot zeroing, and wrap into the
|
||||
``{"global", "proj"}`` context dict consumed by ProjectAttention.
|
||||
|
||||
proj_feats shape:
|
||||
sparse (shape/texture): [N_voxels, C] (batch_ids gives per-voxel batch)
|
||||
dense (SS): [B, N, C]
|
||||
"""
|
||||
if image_attn_mode == "global":
|
||||
return global_cond
|
||||
if proj_feats is None:
|
||||
raise ValueError(f"image_attn_mode={image_attn_mode!r} but trellis2_proj_feats is missing — "
|
||||
f"the stage setup node (or Pixal3DConditioning for SS) should have computed it.")
|
||||
if proj_in_channels is not None and proj_feats.shape[-1] != proj_in_channels:
|
||||
hint = ""
|
||||
if feat_map_hr is None and expected_channels < proj_in_channels:
|
||||
if not has_hr and proj_feats.shape[-1] < proj_in_channels:
|
||||
hint = (" — feature_map_hr is missing for this stage. Connect a NAFModel "
|
||||
"input to Pixal3DConditioning; the shape/texture stages of this "
|
||||
"checkpoint need a NAF-upsampled HR feature map.")
|
||||
raise ValueError(
|
||||
f"proj_feat_pack[{stage!r}] has LR={feat_map_lr.shape[1]} "
|
||||
f"+ HR={feat_map_hr.shape[1] if feat_map_hr is not None else 0} "
|
||||
f"= {expected_channels} channels, sub-model expects {proj_in_channels}.{hint}"
|
||||
f"proj_feats for stage {stage!r} has {proj_feats.shape[-1]} channels, "
|
||||
f"sub-model expects {proj_in_channels}.{hint}"
|
||||
)
|
||||
proj_feats_lr = _back_project_to_tokens(coords_world, feat_map_lr, T, cam_angle,
|
||||
image_resolution=image_resolution,
|
||||
batch_ids=batch_ids)
|
||||
if feat_map_hr is not None:
|
||||
proj_feats_hr = _back_project_to_tokens(coords_world, feat_map_hr, T, cam_angle,
|
||||
image_resolution=image_resolution,
|
||||
batch_ids=batch_ids)
|
||||
proj_feats = torch.cat([proj_feats_lr, proj_feats_hr], dim=-1)
|
||||
else:
|
||||
proj_feats = proj_feats_lr
|
||||
|
||||
# CFG-duplicate proj_feats to match the model's eval batch.
|
||||
if eval_batch is not None and logical_batch is not None and eval_batch > logical_batch:
|
||||
repeats = eval_batch // logical_batch
|
||||
if batch_ids is None:
|
||||
proj_feats = proj_feats.repeat((repeats,) + (1,) * (proj_feats.ndim - 1))
|
||||
else:
|
||||
proj_feats = proj_feats.repeat((repeats, 1))
|
||||
|
||||
# Mirror upstream's neg_cond by zeroing proj for any uncond batch slot.
|
||||
if cond_or_uncond is not None and eval_batch is not None:
|
||||
uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1]
|
||||
@ -1056,10 +1061,6 @@ class Trellis2(nn.Module):
|
||||
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operations = operations or nn
|
||||
# for some reason it passes num_heads = -1
|
||||
if num_heads == -1:
|
||||
num_heads = 12
|
||||
args = {
|
||||
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
|
||||
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
||||
@ -1082,101 +1083,79 @@ class Trellis2(nn.Module):
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **struct_proj_kwargs, **args)
|
||||
else:
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **tex_proj_kwargs, **args)
|
||||
self.guidance_interval = [0.6, 1.0]
|
||||
self.guidance_interval_txt = [0.6, 0.9]
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
cond_or_uncond = transformer_options.get("cond_or_uncond")
|
||||
model_options = {}
|
||||
if hasattr(self, "meta"):
|
||||
model_options = self.meta
|
||||
timestep = timestep.to(x.dtype)
|
||||
embeds = kwargs.get("embeds")
|
||||
if embeds is None:
|
||||
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||
|
||||
is_1024 = True#self.img2shape.resolution == 1024
|
||||
coords = model_options.get("coords", None)
|
||||
coord_counts = model_options.get("coord_counts", None)
|
||||
mode = model_options.get("generation_mode", "structure_generation")
|
||||
proj_feat_pack = model_options.get("proj_feat_pack", None)
|
||||
coord_resolution = model_options.get("coord_resolution", None)
|
||||
# Per-stage cascade metadata
|
||||
coords = kwargs.get("trellis2_coords")
|
||||
coord_counts = kwargs.get("trellis2_coord_counts")
|
||||
mode = kwargs.get("trellis2_generation_mode", "structure_generation")
|
||||
proj_feat_pack = kwargs.get("proj_feat_pack")
|
||||
# Pre-computed per-stage back-projected features
|
||||
proj_feats = kwargs.get("trellis2_proj_feats")
|
||||
|
||||
is_512_run = False
|
||||
is_first_shape_pass = False
|
||||
if mode == "shape_generation_512":
|
||||
is_512_run = True
|
||||
is_first_shape_pass = True
|
||||
mode = "shape_generation"
|
||||
|
||||
if coords is not None:
|
||||
if x.ndim == 4:
|
||||
x = x.squeeze(-1).transpose(1, 2)
|
||||
not_struct_mode = True
|
||||
x = x.squeeze(-1).transpose(1, 2)
|
||||
is_sparse_mode = True
|
||||
else:
|
||||
mode = "structure_generation"
|
||||
not_struct_mode = False
|
||||
is_sparse_mode = False
|
||||
|
||||
if x.size(-1) == 16 and x.size(-2) == 16:
|
||||
mode = "structure_generation"
|
||||
not_struct_mode = False
|
||||
is_sparse_mode = False
|
||||
|
||||
if not not_struct_mode:
|
||||
if not is_sparse_mode:
|
||||
bsz = x.size(0)
|
||||
x = x[:, :8]
|
||||
x = x.view(bsz, 8, 16, 16, 16)
|
||||
|
||||
if is_1024 and not_struct_mode and not is_512_run:
|
||||
if is_sparse_mode and not is_first_shape_pass:
|
||||
context = embeds
|
||||
|
||||
sigmas = transformer_options.get("sigmas")[0].item()
|
||||
if sigmas < 1.00001:
|
||||
timestep *= 1000.0
|
||||
if is_sparse_mode:
|
||||
t_eval = timestep
|
||||
c_eval = context
|
||||
|
||||
if context.size(0) > 1:
|
||||
cond = context.chunk(2)[1]
|
||||
else:
|
||||
cond = context
|
||||
|
||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||
|
||||
if not_struct_mode:
|
||||
orig_bsz = x.shape[0]
|
||||
rule = txt_rule if mode == "texture_generation" else shape_rule
|
||||
|
||||
# CFG Bypass Slicing
|
||||
if rule and orig_bsz > 1:
|
||||
half = orig_bsz // 2
|
||||
x_eval = x[half:]
|
||||
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
|
||||
c_eval = cond
|
||||
else:
|
||||
x_eval = x
|
||||
t_eval = timestep
|
||||
c_eval = context
|
||||
|
||||
B, N, C = x_eval.shape
|
||||
B, N, C = x.shape
|
||||
|
||||
# Vectorized SparseTensor Construction
|
||||
if mode in ["shape_generation", "texture_generation"]:
|
||||
if coord_counts is not None:
|
||||
logical_batch = coord_counts.shape[0]
|
||||
# Duplicate coords if CFG is active
|
||||
# Duplicate sparse coords when the sampler asks for >1 cond
|
||||
# (CFG or otherwise). Each duplicate is offset along col 0
|
||||
# so SparseTensor sees a fresh logical batch.
|
||||
if B > logical_batch:
|
||||
c_pos = coords.clone()
|
||||
c_pos[:, 0] += logical_batch
|
||||
batched_coords = torch.cat([coords, c_pos], dim=0)
|
||||
counts_eval = torch.cat([coord_counts, coord_counts], dim=0)
|
||||
reps = B // logical_batch
|
||||
c_copies = []
|
||||
for i in range(reps):
|
||||
c = coords.clone()
|
||||
c[:, 0] += i * logical_batch
|
||||
c_copies.append(c)
|
||||
batched_coords = torch.cat(c_copies, dim=0)
|
||||
counts_eval = coord_counts.repeat(reps)
|
||||
else:
|
||||
batched_coords = coords
|
||||
counts_eval = coord_counts
|
||||
|
||||
# Create boolean mask [B, N] to drop the padded zeros instantly
|
||||
# Boolean mask [B, N] to drop the padded zeros instantly
|
||||
mask = torch.arange(N, device=x.device).unsqueeze(0) < counts_eval.unsqueeze(1)
|
||||
feats_flat = x_eval[mask]
|
||||
feats_flat = x[mask]
|
||||
else:
|
||||
feats_flat = x_eval.reshape(-1, C)
|
||||
coords_list =[]
|
||||
feats_flat = x.reshape(-1, C)
|
||||
coords_list = []
|
||||
for i in range(B):
|
||||
c = coords.clone()
|
||||
c[:, 0] = i
|
||||
@ -1185,35 +1164,42 @@ class Trellis2(nn.Module):
|
||||
mask = None
|
||||
else:
|
||||
batched_coords = coords
|
||||
feats_flat = x_eval
|
||||
feats_flat = x
|
||||
mask = None
|
||||
|
||||
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
||||
x_st = SparseTensor(
|
||||
feats=feats_flat,
|
||||
coords=batched_coords.to(torch.int32),
|
||||
shape=torch.Size([B] + list(feats_flat.shape[1:])),
|
||||
)
|
||||
|
||||
if mode == "shape_generation":
|
||||
shape_attn = self.image_attn_mode_shape
|
||||
if shape_attn != "global":
|
||||
if coord_resolution is None:
|
||||
raise ValueError("Pixal3D shape_generation requires coord_resolution in model_options; "
|
||||
"EmptyTrellis2ShapeLatent should set it from the input voxel.")
|
||||
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device)
|
||||
xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale)
|
||||
sub_model = self.img2shape_512 if is_512_run else self.img2shape
|
||||
stage_name = "shape_512" if is_512_run else "shape_1024"
|
||||
c_eval = _build_proj_cond(c_eval, shape_attn, proj_feat_pack, xyz_world, batch_ids,
|
||||
eval_batch=B,
|
||||
sub_model = self.img2shape_512 if is_first_shape_pass else self.img2shape
|
||||
stage_name = "shape_512" if is_first_shape_pass else "shape_1024"
|
||||
# batched_coords carries CFG-doubled batch ids in col 0; per-voxel
|
||||
# batch_ids drive uncond-slot masking inside _shape_proj_cond.
|
||||
batch_ids = batched_coords[:, 0].long()
|
||||
logical_batch = coord_counts.shape[0] if coord_counts is not None else B
|
||||
has_hr = bool(proj_feat_pack and proj_feat_pack.get("stages", {})
|
||||
.get(stage_name, {}).get("feature_map_hr") is not None)
|
||||
c_eval = _shape_proj_cond(c_eval, shape_attn, proj_feats,
|
||||
batch_ids=batch_ids,
|
||||
eval_batch=B, logical_batch=logical_batch,
|
||||
proj_in_channels=sub_model.proj_in_channels,
|
||||
stage=stage_name,
|
||||
cond_or_uncond=cond_or_uncond)
|
||||
if is_512_run:
|
||||
out = self.img2shape_512(x_st, t_eval, c_eval)
|
||||
cond_or_uncond=cond_or_uncond,
|
||||
has_hr=has_hr)
|
||||
if is_first_shape_pass:
|
||||
out = self.img2shape_512(x_st, t_eval, c_eval, transformer_options=transformer_options)
|
||||
else:
|
||||
out = self.img2shape(x_st, t_eval, c_eval)
|
||||
out = self.img2shape(x_st, t_eval, c_eval, transformer_options=transformer_options)
|
||||
|
||||
elif mode == "texture_generation":
|
||||
if self.shape2txt is None:
|
||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||
slat = model_options.get("shape_slat")
|
||||
slat = kwargs.get("trellis2_shape_slat")
|
||||
if slat is None:
|
||||
raise ValueError("shape_slat can't be None")
|
||||
|
||||
@ -1227,49 +1213,40 @@ class Trellis2(nn.Module):
|
||||
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1))
|
||||
tex_attn = self.image_attn_mode_texture
|
||||
if tex_attn != "global":
|
||||
if coord_resolution is None:
|
||||
raise ValueError("Pixal3D texture_generation requires coord_resolution in model_options; "
|
||||
"EmptyTrellis2LatentTexture should set it from the input voxel.")
|
||||
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device)
|
||||
xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale)
|
||||
c_eval = _build_proj_cond(c_eval, tex_attn, proj_feat_pack, xyz_world, batch_ids,
|
||||
eval_batch=B,
|
||||
batch_ids = batched_coords[:, 0].long()
|
||||
logical_batch = coord_counts.shape[0] if coord_counts is not None else B
|
||||
has_hr = bool(proj_feat_pack and proj_feat_pack.get("stages", {})
|
||||
.get("tex_1024", {}).get("feature_map_hr") is not None)
|
||||
c_eval = _shape_proj_cond(c_eval, tex_attn, proj_feats,
|
||||
batch_ids=batch_ids,
|
||||
eval_batch=B, logical_batch=logical_batch,
|
||||
proj_in_channels=self.shape2txt.proj_in_channels,
|
||||
stage="tex_1024",
|
||||
cond_or_uncond=cond_or_uncond)
|
||||
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||
cond_or_uncond=cond_or_uncond,
|
||||
has_hr=has_hr)
|
||||
out = self.shape2txt(x_st, t_eval, c_eval, transformer_options=transformer_options)
|
||||
|
||||
else: # structure
|
||||
orig_bsz = x.shape[0]
|
||||
struct_attn = self.image_attn_mode_structure
|
||||
if shape_rule and orig_bsz > 1:
|
||||
half = orig_bsz // 2
|
||||
x_eval = x[half:]
|
||||
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
|
||||
struct_cond = cond
|
||||
if struct_attn != "global":
|
||||
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", half, x.device)
|
||||
grid_xyz = _dense_grid_proj_world(16, mesh_scale, half, device=x.device)
|
||||
struct_cond = _build_proj_cond(cond, struct_attn, proj_feat_pack, grid_xyz,
|
||||
eval_batch=half,
|
||||
proj_in_channels=self.structure_model.proj_in_channels,
|
||||
stage="ss",
|
||||
cond_or_uncond=cond_or_uncond)
|
||||
out = self.structure_model(x_eval, t_eval, struct_cond)
|
||||
out = out.repeat(2, 1, 1, 1, 1)
|
||||
else:
|
||||
struct_cond = context
|
||||
if struct_attn != "global":
|
||||
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", orig_bsz, x.device)
|
||||
grid_xyz = _dense_grid_proj_world(16, mesh_scale, orig_bsz, device=x.device)
|
||||
struct_cond = _build_proj_cond(context, struct_attn, proj_feat_pack, grid_xyz,
|
||||
eval_batch=orig_bsz,
|
||||
proj_in_channels=self.structure_model.proj_in_channels,
|
||||
stage="ss",
|
||||
cond_or_uncond=cond_or_uncond)
|
||||
out = self.structure_model(x, timestep, struct_cond)
|
||||
has_hr_ss = bool(proj_feat_pack and proj_feat_pack.get("stages", {})
|
||||
.get("ss", {}).get("feature_map_hr") is not None)
|
||||
logical_batch_ss = (
|
||||
proj_feat_pack["mesh_scale"].shape[0]
|
||||
if (proj_feat_pack is not None and torch.is_tensor(proj_feat_pack.get("mesh_scale")))
|
||||
else x.shape[0]
|
||||
)
|
||||
struct_cond = context
|
||||
if struct_attn != "global":
|
||||
struct_cond = _shape_proj_cond(context, struct_attn, proj_feats,
|
||||
batch_ids=None,
|
||||
eval_batch=x.shape[0], logical_batch=logical_batch_ss,
|
||||
proj_in_channels=self.structure_model.proj_in_channels,
|
||||
stage="ss",
|
||||
cond_or_uncond=cond_or_uncond,
|
||||
has_hr=has_hr_ss)
|
||||
out = self.structure_model(x, timestep, struct_cond, transformer_options=transformer_options)
|
||||
|
||||
if not_struct_mode:
|
||||
if is_sparse_mode:
|
||||
if mask is not None:
|
||||
# Instantly scatter the valid tokens back into a padded rectangular tensor
|
||||
padded_out = torch.zeros((B, N, out.feats.shape[-1]), device=x.device, dtype=out.feats.dtype)
|
||||
@ -1277,9 +1254,6 @@ class Trellis2(nn.Module):
|
||||
out_tensor = padded_out.transpose(1, 2).unsqueeze(-1)
|
||||
else:
|
||||
out_tensor = out.feats.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
||||
|
||||
if rule and orig_bsz > 1:
|
||||
out_tensor = out_tensor.repeat(2, 1, 1, 1)
|
||||
return out_tensor
|
||||
else:
|
||||
out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 0, 24))
|
||||
|
||||
@ -1646,6 +1646,17 @@ class Trellis2(BaseModel):
|
||||
out = super().extra_conds(**kwargs)
|
||||
embeds = kwargs.get("embeds")
|
||||
out["embeds"] = comfy.conds.CONDRegular(embeds)
|
||||
# CONDConstant: shared across pos/neg
|
||||
for k in ("trellis2_coords", "trellis2_coord_counts",
|
||||
"trellis2_generation_mode", "trellis2_shape_slat",
|
||||
"trellis2_proj_feats"):
|
||||
v = kwargs.get(k)
|
||||
if v is not None:
|
||||
out[k] = comfy.conds.CONDConstant(v)
|
||||
# Pixal3D's per-stage feature maps + camera params travel as a dict
|
||||
proj_feat_pack = kwargs.get("proj_feat_pack")
|
||||
if proj_feat_pack is not None:
|
||||
out["proj_feat_pack"] = comfy.conds.CONDConstant(proj_feat_pack)
|
||||
return out
|
||||
|
||||
class WAN21_FlowRVS(WAN21):
|
||||
|
||||
@ -1322,10 +1322,10 @@ class Trellis2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "trellis2"
|
||||
}
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
"multiplier": 1.0
|
||||
}
|
||||
|
||||
memory_usage_factor = 3.5
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types, io
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
from comfy.ldm.trellis2.model import _build_proj_transform_matrix, _project_points_to_image
|
||||
from comfy.ldm.trellis2.model import (
|
||||
_build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats,
|
||||
)
|
||||
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
|
||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
import comfy.model_management
|
||||
@ -14,48 +16,9 @@ import math
|
||||
import torch
|
||||
|
||||
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
||||
Pixal3DProjPack = io.Custom("PIXAL3D_PROJ_PACK")
|
||||
NAFModel = io.Custom("NAF_MODEL")
|
||||
|
||||
|
||||
# Pixal3D trains in a 90°-X-rotated grid frame (F_p). We un-rotate decoder outputs for
|
||||
# user-facing previews/meshes, then re-rotate before feeding coords back to the shape DiT.
|
||||
|
||||
def _pixal3d_unrotate_voxel_data(data: torch.Tensor) -> torch.Tensor:
|
||||
if data.ndim == 4:
|
||||
return data.flip(-1).permute(0, 1, 3, 2).contiguous()
|
||||
if data.ndim == 5:
|
||||
return data.flip(-1).permute(0, 1, 2, 4, 3).contiguous()
|
||||
raise ValueError(f"unexpected voxel shape {tuple(data.shape)}")
|
||||
|
||||
|
||||
def _pixal3d_rerotate_voxel_data(data: torch.Tensor) -> torch.Tensor:
|
||||
if data.ndim == 4:
|
||||
return data.permute(0, 1, 3, 2).flip(-1).contiguous()
|
||||
if data.ndim == 5:
|
||||
return data.permute(0, 1, 2, 4, 3).flip(-1).contiguous()
|
||||
raise ValueError(f"unexpected voxel shape {tuple(data.shape)}")
|
||||
|
||||
|
||||
def _pixal3d_unrotate_vertices(vertices: torch.Tensor) -> torch.Tensor:
    """Negate the first and third components of each vertex.

    Args:
        vertices: Tensor whose last axis holds three components per vertex.

    Returns:
        ``vertices`` itself when it has no elements; otherwise a new
        contiguous tensor ``[-x, y, -z]`` along the last axis.

    Raises:
        ValueError: If a non-empty ``vertices`` does not have exactly three
            components on its last axis.
    """
    if vertices.numel() == 0:
        return vertices
    # Fail with an explicit shape message (matching the sibling helpers)
    # rather than an opaque tuple-unpacking error from unbind().
    if vertices.shape[-1] != 3:
        raise ValueError(f"unexpected vertex shape {tuple(vertices.shape)}")
    x, y, z = vertices.unbind(-1)
    return torch.stack([-x, y, -z], dim=-1).contiguous()
|
||||
|
||||
|
||||
def _pixal3d_unrotate_sparse_coords(coords: torch.Tensor, resolution: int) -> torch.Tensor:
    """Mirror the first and last of the three trailing coordinate columns.

    For each row, the third-from-last and last columns are replaced by
    ``(resolution - 1) - value``; the middle one (and a leading fourth column,
    if present) is passed through unchanged.

    Args:
        coords: Tensor whose last axis has 3 columns ``(i, j, k)`` or
            4 columns ``(b, i, j, k)``.
        resolution: Grid size used for mirroring; the maximum index is
            ``resolution - 1``.

    Returns:
        ``coords`` itself when it has no elements; otherwise a new contiguous
        tensor of the same shape.

    Raises:
        ValueError: If a non-empty ``coords`` has a last axis that is neither
            3 nor 4 wide.
    """
    if coords.numel() == 0:
        return coords
    if coords.shape[-1] not in (3, 4):
        raise ValueError(f"unexpected coord shape {tuple(coords.shape)}")
    max_index = resolution - 1
    columns = list(coords.unbind(-1))
    # Negative indices address i and k in both the 3- and 4-column layouts,
    # so the two original branches collapse into one.
    columns[-3] = max_index - columns[-3]
    columns[-1] = max_index - columns[-1]
    return torch.stack(columns, dim=-1).contiguous()
|
||||
|
||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||
if len(sample_shape) == 5:
|
||||
@ -202,15 +165,21 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae):
|
||||
|
||||
resolution = int(vae.first_stage_model.resolution.item())
|
||||
# Mesh grid_size must match the actual coord resolution the upstream
|
||||
# stage was run at (1024 cascade -> 64, 1536 cascade -> 96). The VAE's
|
||||
# built-in `.resolution` buffer defaults to 1024 and is otherwise stale;
|
||||
# take coord_resolution from the latent dict if the stage node set it.
|
||||
coord_resolution = samples.get("coord_resolution")
|
||||
if coord_resolution is not None:
|
||||
resolution = int(coord_resolution) * 16
|
||||
else:
|
||||
resolution = int(vae.first_stage_model.resolution.item())
|
||||
sample_tensor = samples["samples"]
|
||||
device = comfy.model_management.get_torch_device()
|
||||
coords = samples["coords"]
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
trellis_vae = vae.first_stage_model
|
||||
coord_counts = samples.get("coord_counts")
|
||||
pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None
|
||||
|
||||
samples = samples["samples"]
|
||||
if coord_counts is None:
|
||||
@ -236,10 +205,6 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
|
||||
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
|
||||
|
||||
if pixal3d_mode:
|
||||
for m in mesh:
|
||||
m.vertices = _pixal3d_unrotate_vertices(m.vertices)
|
||||
|
||||
face_list = [m.faces for m in mesh]
|
||||
vert_list = [m.vertices for m in mesh]
|
||||
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
|
||||
@ -276,7 +241,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
trellis_vae = vae.first_stage_model
|
||||
coord_counts = samples.get("coord_counts")
|
||||
pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None
|
||||
|
||||
samples = samples["samples"]
|
||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||
@ -288,7 +252,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
|
||||
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
|
||||
color_feats = voxel.feats[:, :3]
|
||||
voxel_coords = voxel.coords#[:, 1:]
|
||||
voxel_coords = voxel.coords
|
||||
|
||||
if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
||||
spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords
|
||||
@ -297,9 +261,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
else:
|
||||
tex_resolution = 1024
|
||||
|
||||
if pixal3d_mode:
|
||||
voxel_coords = _pixal3d_unrotate_sparse_coords(voxel_coords, resolution=tex_resolution)
|
||||
|
||||
voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution)
|
||||
return IO.NodeOutput(voxel)
|
||||
|
||||
@ -312,7 +273,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("resolution", options=["32", "64"], default="32")
|
||||
IO.Combo.Input("resolution", options=["32", "64"], default="32"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Voxel.Output("voxel"),
|
||||
@ -338,129 +299,132 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
ratio = current_res // resolution
|
||||
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
|
||||
voxel_data = decoded.squeeze(1).float()
|
||||
if samples.get("model_options", {}).get("proj_feat_pack") is not None:
|
||||
voxel_data = _pixal3d_unrotate_voxel_data(voxel_data)
|
||||
out = Types.VOXEL(voxel_data)
|
||||
return IO.NodeOutput(out)
|
||||
return IO.NodeOutput(Types.VOXEL(voxel_data))
|
||||
|
||||
class Trellis2UpsampleCascade(IO.ComfyNode):
|
||||
class Trellis2UpsampleStage(IO.ComfyNode):
|
||||
"""Cascade-upsamples a 512-resolution shape latent into high-resolution
|
||||
sparse coords and sets up the second shape-stage sampling pass at the
|
||||
target resolution, attaching per-stage metadata to the conditioning for
|
||||
the model to consume via extra_conds. target_resolution is reduced in
|
||||
128-step decrements until the unique upsampled coord count fits under
|
||||
max_tokens (floor 1024)."""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Trellis2UpsampleCascade",
|
||||
node_id="Trellis2UpsampleStage",
|
||||
category="latent/3d",
|
||||
display_name="Trellis2 Upsample Cascade",
|
||||
description="Upsamples low-resolution Trellis2 shape latents into higher resolution coordinates while respecting the maximum token budget.",
|
||||
display_name="Trellis2 Upsample Stage",
|
||||
inputs=[
|
||||
IO.Latent.Input("shape_latent"),
|
||||
IO.Conditioning.Input("positive"),
|
||||
IO.Conditioning.Input("negative"),
|
||||
IO.Latent.Input("shape_latent", tooltip="The 512-resolution shape latent output from the first shape-stage KSampler."),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024", tooltip="Controls output detail level for upsampling."),
|
||||
IO.Int.Input("max_tokens", default=49152, min=1024, max=100000,
|
||||
tooltip=(
|
||||
"Maximum number of output elements (coordinates) allowed after upsampling. "
|
||||
"Used to limit memory usage and control mesh density."
|
||||
))
|
||||
)),
|
||||
],
|
||||
outputs=[
|
||||
IO.Voxel.Output(
|
||||
"high_res_voxel",
|
||||
tooltip=(
|
||||
"High-resolution sparse coordinates produced after cascade upsampling. "
|
||||
"Represents the refined 3D structure at target resolution."
|
||||
)
|
||||
)
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, shape_latent, vae, target_resolution, max_tokens):
|
||||
shape_latent_512 = shape_latent
|
||||
device = comfy.model_management.get_torch_device()
|
||||
prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape)
|
||||
@staticmethod
def _quantize_unique(hr_coords: torch.Tensor, lr_resolution: int, hr_resolution: int) -> torch.Tensor:
    """Quantize upsampled coords onto the ``hr_resolution // 16`` grid and drop duplicate rows.

    Column 0 of ``hr_coords`` is passed through untouched; the remaining
    columns are mapped with ``int((v + 0.5) * (hr_resolution // 16) / lr_resolution)``.

    Args:
        hr_coords: 2-D tensor; column 0 is kept as-is, columns 1.. are the
            values to quantize.
        lr_resolution: Resolution the incoming coordinate values refer to.
        hr_resolution: Target resolution; the output grid size is
            ``hr_resolution // 16``.

    Returns:
        The unique quantized rows (``Tensor.unique(dim=0)``).
        ``hr_coords`` is never modified.
    """
    # Fold the two scalar divisions into one and chain the float math in-place
    # to avoid several full-size fp32 transients per call.
    scale = (hr_resolution // 16) / lr_resolution
    # copy=True is required: for an input that is already float32, a plain
    # .float() returns a view of hr_coords and the in-place ops below would
    # corrupt the caller's tensor on every resolution-search iteration.
    spatial = hr_coords[:, 1:].to(torch.float32, copy=True)
    spatial.add_(0.5).mul_(scale)
    quant = torch.cat([hr_coords[:, :1], spatial.int()], dim=1)
    return quant.unique(dim=0)
|
||||
|
||||
coord_counts = shape_latent_512.get("coord_counts")
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, shape_latent, vae, target_resolution, max_tokens):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
prepare_trellis_vae_for_decode(vae, shape_latent["samples"].shape)
|
||||
|
||||
coord_counts = shape_latent.get("coord_counts")
|
||||
decoder = vae.first_stage_model.shape_dec
|
||||
lr_resolution = 512
|
||||
target_resolution = int(target_resolution)
|
||||
|
||||
if coord_counts is None:
|
||||
feats, coords_512 = flatten_batched_sparse_latent(
|
||||
shape_latent_512["samples"],
|
||||
shape_latent_512["coords"],
|
||||
coord_counts,
|
||||
)
|
||||
feats = feats.to(device)
|
||||
coords_512 = coords_512.to(device)
|
||||
slat = shape_norm(feats, coords_512)
|
||||
slat.feats = slat.feats.to(next(decoder.parameters()).dtype)
|
||||
hr_coords = decoder.upsample(slat, upsample_times=4)
|
||||
|
||||
hr_resolution = target_resolution
|
||||
while True:
|
||||
quant_coords = torch.cat([
|
||||
hr_coords[:, :1],
|
||||
((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
|
||||
], dim=1)
|
||||
final_coords = quant_coords.unique(dim=0)
|
||||
num_tokens = final_coords.shape[0]
|
||||
|
||||
if num_tokens < max_tokens or hr_resolution <= 1024:
|
||||
break
|
||||
hr_resolution -= 128
|
||||
|
||||
return IO.NodeOutput(final_coords,)
|
||||
|
||||
items = split_batched_sparse_latent(
|
||||
shape_latent_512["samples"],
|
||||
shape_latent_512["coords"],
|
||||
coord_counts,
|
||||
)
|
||||
decoder_dtype = next(decoder.parameters()).dtype
|
||||
|
||||
sample_hr_coords = []
|
||||
for feats_i, coords_i in items:
|
||||
feats_i = feats_i.to(device)
|
||||
coords_i = coords_i.to(device).clone()
|
||||
coords_i[:, 0] = 0
|
||||
slat_i = shape_norm(feats_i, coords_i)
|
||||
slat_i.feats = slat_i.feats.to(decoder_dtype)
|
||||
sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4))
|
||||
# Decode each sample's HR coords, then search for the largest hr_resolution
|
||||
# that fits under max_tokens across all samples.
|
||||
if coord_counts is None:
|
||||
feats, coords_512 = flatten_batched_sparse_latent(
|
||||
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
||||
)
|
||||
slat = shape_norm(feats.to(device), coords_512.to(device))
|
||||
slat.feats = slat.feats.to(decoder_dtype)
|
||||
sample_hr_coords = [decoder.upsample(slat, upsample_times=4)]
|
||||
else:
|
||||
items = split_batched_sparse_latent(
|
||||
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
||||
)
|
||||
sample_hr_coords = []
|
||||
for feats_i, coords_i in items:
|
||||
coords_i = coords_i.to(device).clone()
|
||||
coords_i[:, 0] = 0
|
||||
slat_i = shape_norm(feats_i.to(device), coords_i)
|
||||
slat_i.feats = slat_i.feats.to(decoder_dtype)
|
||||
sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4))
|
||||
|
||||
# Resolution search — cache the final iteration's quantized unique tensors
|
||||
# so we don't recompute .unique() per sample after picking hr_resolution.
|
||||
hr_resolution = target_resolution
|
||||
quant_unique_list = []
|
||||
while True:
|
||||
quant_unique_list = []
|
||||
exceeds_limit = False
|
||||
for hr_coords_i in sample_hr_coords:
|
||||
quant_coords_i = torch.cat([
|
||||
hr_coords_i[:, :1],
|
||||
((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
|
||||
], dim=1)
|
||||
if quant_coords_i.unique(dim=0).shape[0] >= max_tokens:
|
||||
qu = cls._quantize_unique(hr_coords_i, lr_resolution, hr_resolution)
|
||||
quant_unique_list.append(qu)
|
||||
if qu.shape[0] >= max_tokens:
|
||||
exceeds_limit = True
|
||||
break
|
||||
if not exceeds_limit or hr_resolution <= 1024:
|
||||
if not exceeds_limit:
|
||||
break
|
||||
if hr_resolution <= 1024:
|
||||
for k in range(len(quant_unique_list), len(sample_hr_coords)):
|
||||
quant_unique_list.append(
|
||||
cls._quantize_unique(sample_hr_coords[k], lr_resolution, hr_resolution)
|
||||
)
|
||||
break
|
||||
hr_resolution -= 128
|
||||
|
||||
final_coords_list = []
|
||||
output_coord_counts = []
|
||||
for sample_offset, hr_coords_i in enumerate(sample_hr_coords):
|
||||
quant_coords_i = torch.cat([
|
||||
hr_coords_i[:, :1],
|
||||
((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
|
||||
], dim=1)
|
||||
final_coords_i = quant_coords_i.unique(dim=0)
|
||||
final_coords_i = final_coords_i.clone()
|
||||
final_coords_i[:, 0] = sample_offset
|
||||
final_coords_list.append(final_coords_i)
|
||||
output_coord_counts.append(int(final_coords_i.shape[0]))
|
||||
# Rewrite batch column to match per-sample offset and concat.
|
||||
per_sample_counts = []
|
||||
for sample_offset, qu in enumerate(quant_unique_list):
|
||||
qu[:, 0] = sample_offset
|
||||
per_sample_counts.append(int(qu.shape[0]))
|
||||
coords = torch.cat(quant_unique_list, dim=0)
|
||||
counts = torch.tensor(per_sample_counts, dtype=torch.int64)
|
||||
coord_resolution = hr_resolution // 16
|
||||
|
||||
coords = torch.cat(final_coords_list, dim=0)
|
||||
output = Types.VOXEL(coords)
|
||||
output.coord_counts = torch.tensor(output_coord_counts, dtype=torch.int64)
|
||||
output.resolutions = torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64)
|
||||
output.upsampled = True
|
||||
batch_size, _, max_tokens_out = infer_batched_coord_layout(coords)
|
||||
latent = torch.zeros(batch_size, 32, max_tokens_out, 1)
|
||||
|
||||
return IO.NodeOutput(output,)
|
||||
extras = {
|
||||
"trellis2_generation_mode": "shape_generation",
|
||||
"trellis2_coords": coords,
|
||||
"trellis2_coord_counts": counts,
|
||||
}
|
||||
proj_pack = _proj_pack_from_conditioning(positive)
|
||||
if proj_pack is not None:
|
||||
extras["trellis2_proj_feats"] = compute_stage_proj_feats(
|
||||
proj_pack, "shape_1024", coords=coords, coord_resolution=coord_resolution,
|
||||
)
|
||||
positive_out = _conditioning_set_extras(positive, extras)
|
||||
negative_out = _conditioning_set_extras(negative, extras)
|
||||
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
|
||||
"coord_resolution": coord_resolution, "type": "trellis2"}
|
||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||
|
||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||
@ -639,130 +603,152 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
negative = [[neg_cond_batched, {"embeds": neg_embeds_batched}]]
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
def _proj_pack_from_conditioning(conditioning):
    """Return the ``proj_feat_pack`` stored in the first conditioning entry, or None.

    Pixal3DConditioning ships the pack in ``cond[0][1]["proj_feat_pack"]``;
    Trellis2Conditioning doesn't set it, so vanilla Trellis2 (or no
    conditioning connected) yields None.

    Args:
        conditioning: A conditioning list whose entries look like
            ``[tensor, options_dict]``; may be None or empty.

    Returns:
        The value under ``"proj_feat_pack"`` in the first entry's dict, or
        None when the conditioning is empty, the first entry is not a
        ``[value, dict]`` pair, or the key is absent.
    """
    if not conditioning:
        return None
    entry = conditioning[0]
    # Only the first entry is inspected; anything that is not a
    # (value, dict) pair is treated as "no pack".
    if not isinstance(entry, (list, tuple)) or len(entry) < 2:
        return None
    if not isinstance(entry[1], dict):
        return None
    return entry[1].get("proj_feat_pack")
|
||||
|
||||
|
||||
def _conditioning_set_extras(conditioning, extras: dict):
    """Return a copy of ``conditioning`` with ``extras`` merged into each entry's dict.

    Same shallow-copy pattern ControlNetApplyAdvanced uses: each entry's dict
    is copied before updating, so the upstream conditioning is not mutated.

    Args:
        conditioning: Iterable of entries, normally ``[value, options_dict]``.
        extras: Key/value pairs to merge; on key collisions ``extras`` wins.

    Returns:
        A new list. Entries shaped like ``[value, dict]`` become
        ``[value, merged_dict]`` (only the first two elements are carried
        over); any other entry is appended unchanged.
    """
    out = []
    for entry in conditioning:
        is_pair = (
            isinstance(entry, (list, tuple))
            and len(entry) >= 2
            and isinstance(entry[1], dict)
        )
        if not is_pair:
            out.append(entry)
            continue
        # Copy first so the caller's dict is left untouched.
        merged = entry[1].copy()
        merged.update(extras)
        out.append([entry[0], merged])
    return out
|
||||
|
||||
|
||||
class Trellis2ShapeStage(IO.ComfyNode):
|
||||
"""Sets up the first shape-stage sampling pass: extracts sparse coords from
|
||||
the dense structure voxel produced by VaeDecodeStructureTrellis2, builds an
|
||||
empty sparse latent, and attaches per-stage metadata to the conditioning so
|
||||
the model reads it via extra_conds at sample time. For the second shape pass
|
||||
(post-upsample), use Trellis2UpsampleStage instead — it combines the cascade
|
||||
and the second-pass stage setup."""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyTrellis2ShapeLatent",
|
||||
node_id="Trellis2ShapeStage",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Conditioning.Input("positive"),
|
||||
IO.Conditioning.Input("negative"),
|
||||
IO.Voxel.Input(
|
||||
"voxel",
|
||||
tooltip=(
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled voxel coordinates from a previous cascade stage."
|
||||
)
|
||||
),
|
||||
Pixal3DProjPack.Input(
|
||||
"proj_feat_pack",
|
||||
optional=True,
|
||||
tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.",
|
||||
tooltip="Dense structure voxel from VaeDecodeStructureTrellis2.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, voxel, proj_feat_pack=None):
|
||||
is_512_pass = False
|
||||
coord_resolution = None
|
||||
upsampled = hasattr(voxel, "upsampled")
|
||||
if upsampled:
|
||||
if hasattr(voxel, "resolutions") and voxel.resolutions is not None:
|
||||
coord_resolution = int(voxel.resolutions[0].item()) // 16
|
||||
voxel = voxel.data
|
||||
|
||||
if not upsampled:
|
||||
voxel_data = voxel.data
|
||||
if proj_feat_pack is not None:
|
||||
voxel_data = _pixal3d_rerotate_voxel_data(voxel_data)
|
||||
decoded = voxel_data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
is_512_pass = True
|
||||
coord_resolution = int(decoded.shape[-1])
|
||||
def execute(cls, positive, negative, voxel):
|
||||
decoded = voxel.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
coord_resolution = int(decoded.shape[-1])
|
||||
|
||||
# Dispatch based on the upstream voxel resolution, mirroring upstream's
|
||||
# pipeline_type → ss_res table:
|
||||
# coord_res == 32 → first cascade shape pass OR pure-512 pipeline
|
||||
# (img2shape_512 + shape_512 proj stage, 512 DINO).
|
||||
# coord_res > 32 → pure-1024 non-cascade pipeline
|
||||
# (img2shape + shape_1024 proj stage, 1024 DINO).
|
||||
if coord_resolution <= 32:
|
||||
mode = "shape_generation_512"
|
||||
stage = "shape_512"
|
||||
else:
|
||||
coords = voxel.int()
|
||||
mode = "shape_generation"
|
||||
stage = "shape_1024"
|
||||
|
||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
in_channels = 32
|
||||
# image like format
|
||||
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
|
||||
latent = torch.zeros(batch_size, 32, max_tokens, 1)
|
||||
|
||||
if is_512_pass:
|
||||
generation_mode = "shape_generation_512"
|
||||
else:
|
||||
generation_mode = "shape_generation"
|
||||
model_options = {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}
|
||||
if coord_resolution is not None:
|
||||
model_options["coord_resolution"] = coord_resolution
|
||||
if proj_feat_pack is not None:
|
||||
model_options["proj_feat_pack"] = proj_feat_pack
|
||||
return IO.NodeOutput({"samples": latent, "coords": coords, "coord_counts": counts, "type": "trellis2",
|
||||
"model_options": model_options})
|
||||
extras = {
|
||||
"trellis2_generation_mode": mode,
|
||||
"trellis2_coords": coords,
|
||||
"trellis2_coord_counts": counts,
|
||||
}
|
||||
proj_pack = _proj_pack_from_conditioning(positive)
|
||||
if proj_pack is not None:
|
||||
extras["trellis2_proj_feats"] = compute_stage_proj_feats(
|
||||
proj_pack, stage, coords=coords, coord_resolution=coord_resolution,
|
||||
)
|
||||
positive_out = _conditioning_set_extras(positive, extras)
|
||||
negative_out = _conditioning_set_extras(negative, extras)
|
||||
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
|
||||
"coord_resolution": coord_resolution, "type": "trellis2"}
|
||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||
|
||||
class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
class Trellis2TextureStage(IO.ComfyNode):
|
||||
"""Sets up the texture-stage sampling pass. Reads coords / coord_counts /
|
||||
coord_resolution and the shape_slat (the per-voxel shape latent) from the
|
||||
incoming shape_latent dict — set there by Trellis2ShapeStage or
|
||||
Trellis2UpsampleStage. Builds an empty sparse latent at the same coord
|
||||
layout and attaches per-stage metadata to the conditioning."""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyTrellis2LatentTexture",
|
||||
node_id="Trellis2TextureStage",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input(
|
||||
"voxel",
|
||||
tooltip=(
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled voxel coordinates from a previous cascade stage."
|
||||
)
|
||||
),
|
||||
IO.Conditioning.Input("positive"),
|
||||
IO.Conditioning.Input("negative"),
|
||||
IO.Latent.Input("shape_latent"),
|
||||
Pixal3DProjPack.Input(
|
||||
"proj_feat_pack",
|
||||
optional=True,
|
||||
tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, voxel, shape_latent, proj_feat_pack=None):
|
||||
def execute(cls, positive, negative, shape_latent):
|
||||
channels = 32
|
||||
coord_resolution = None
|
||||
upsampled = hasattr(voxel, "upsampled")
|
||||
if upsampled:
|
||||
if hasattr(voxel, "resolutions") and voxel.resolutions is not None:
|
||||
coord_resolution = int(voxel.resolutions[0].item()) // 16
|
||||
voxel = voxel.data
|
||||
|
||||
if not upsampled:
|
||||
voxel_data = voxel.data
|
||||
if proj_feat_pack is not None:
|
||||
voxel_data = _pixal3d_rerotate_voxel_data(voxel_data)
|
||||
decoded = voxel_data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
coord_resolution = int(decoded.shape[-1])
|
||||
else:
|
||||
coords = voxel.int()
|
||||
coords = shape_latent["coords"]
|
||||
coord_resolution = shape_latent.get("coord_resolution")
|
||||
|
||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
|
||||
shape_latent = shape_latent["samples"]
|
||||
if shape_latent.ndim == 4:
|
||||
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||
shape_slat = shape_latent["samples"]
|
||||
if shape_slat.ndim == 4:
|
||||
shape_slat = shape_slat.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||
|
||||
latent = torch.zeros(batch_size, channels, max_tokens, 1)
|
||||
model_options = {"generation_mode": "texture_generation", "coords": coords, "coord_counts": counts, "shape_slat": shape_latent}
|
||||
extras = {
|
||||
"trellis2_generation_mode": "texture_generation",
|
||||
"trellis2_coords": coords,
|
||||
"trellis2_coord_counts": counts,
|
||||
"trellis2_shape_slat": shape_slat,
|
||||
}
|
||||
proj_pack = _proj_pack_from_conditioning(positive)
|
||||
if proj_pack is not None and coord_resolution is not None:
|
||||
extras["trellis2_proj_feats"] = compute_stage_proj_feats(
|
||||
proj_pack, "tex_1024", coords=coords, coord_resolution=coord_resolution,
|
||||
)
|
||||
|
||||
positive_out = _conditioning_set_extras(positive, extras)
|
||||
negative_out = _conditioning_set_extras(negative, extras)
|
||||
out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts}
|
||||
if coord_resolution is not None:
|
||||
model_options["coord_resolution"] = coord_resolution
|
||||
if proj_feat_pack is not None:
|
||||
model_options["proj_feat_pack"] = proj_feat_pack
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
|
||||
"model_options": model_options})
|
||||
out_latent["coord_resolution"] = coord_resolution
|
||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||
|
||||
|
||||
class EmptyTrellis2LatentStructure(IO.ComfyNode):
|
||||
@ -773,30 +759,17 @@ class EmptyTrellis2LatentStructure(IO.ComfyNode):
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||
Pixal3DProjPack.Input(
|
||||
"proj_feat_pack",
|
||||
optional=True,
|
||||
tooltip="Pixal3D pixel-aligned projection pack. Leave empty for vanilla Trellis2.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, batch_size, proj_feat_pack=None):
|
||||
# Trellis2.forward slices x[:, :8] and pads out to 32; KSampler residual math
|
||||
# needs the empty latent to match latent_format (32-channel).
|
||||
def execute(cls, batch_size):
|
||||
in_channels = 32
|
||||
resolution = 16
|
||||
latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution)
|
||||
output = {
|
||||
"samples": latent,
|
||||
"type": "trellis2",
|
||||
}
|
||||
if proj_feat_pack is not None:
|
||||
output["model_options"] = {"proj_feat_pack": proj_feat_pack}
|
||||
return IO.NodeOutput(output)
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
|
||||
h_p = w_p = image_size // patch_size
|
||||
@ -813,20 +786,11 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
|
||||
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
|
||||
|
||||
|
||||
def _fov_from_moge_intrinsics(moge_intrinsics: torch.Tensor) -> float:
    """Derive a single field-of-view angle (radians) from an intrinsics tensor.

    Reads the ``[..., 0, 0]`` element as ``fx`` and computes
    ``2 * atan(0.5 / fx)``, averaged over all leading dimensions.
    NOTE(review): the 0.5 numerator suggests fx is normalized by image width - confirm.

    Args:
        moge_intrinsics: Tensor with at least two trailing matrix axes.

    Returns:
        The mean field of view as a Python float.
    """
    fx = moge_intrinsics[..., 0, 0].float()
    # Clamp guards against division by zero for degenerate focal lengths.
    fov = 2.0 * torch.atan(0.5 / fx.clamp(min=1e-4))
    return float(fov.mean().item())
|
||||
|
||||
|
||||
def _run_dinov3_with_patches(model, cropped_pil, image_size):
|
||||
# Pixal3D's cross-attn was trained against CLS + registers only (~5 tokens), not the
|
||||
# full patch grid. The patch grid goes to the proj branch via patches_2d.
|
||||
def _run_dinov3_with_patches(model, composite, image_size):
|
||||
model_internal = model.model
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
resized = cropped_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
|
||||
img_np = np.array(resized).astype(np.float32) / 255.0
|
||||
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
||||
img_t = comfy.utils.common_upscale(composite, image_size, image_size, "lanczos", "disabled")
|
||||
img_t = img_t.to(torch_device)
|
||||
img_t = (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||
model_internal.image_size = image_size
|
||||
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
|
||||
@ -838,48 +802,59 @@ def _run_dinov3_with_patches(model, cropped_pil, image_size):
|
||||
|
||||
|
||||
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
||||
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
pil_img = Image.fromarray(img_np)
|
||||
pil_mask = Image.fromarray(mask_np)
|
||||
max_size = max(pil_img.size)
|
||||
scale = min(1.0, max_image_size / max_size)
|
||||
if scale < 1.0:
|
||||
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
|
||||
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
|
||||
scene_size = (pil_img.width, pil_img.height)
|
||||
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
|
||||
rgba_np[:, :, :3] = np.array(pil_img)
|
||||
rgba_np[:, :, 3] = np.array(pil_mask)
|
||||
alpha = rgba_np[:, :, 3]
|
||||
bbox_coords = np.argwhere(alpha > 0.8 * 255)
|
||||
if len(bbox_coords) > 0:
|
||||
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
|
||||
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
|
||||
img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float()
|
||||
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float()
|
||||
# Upstream went float→PIL uint8 implicitly; match that to keep composite bit-exact.
|
||||
img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
|
||||
mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
|
||||
|
||||
H, W = img.shape[-2:]
|
||||
if max(H, W) > max_image_size:
|
||||
scale = max_image_size / max(H, W)
|
||||
new_w, new_h = int(W * scale), int(H * scale)
|
||||
img = comfy.utils.common_upscale(img, new_w, new_h, "lanczos", "disabled")
|
||||
mask = comfy.utils.common_upscale(mask, new_w, new_h, "nearest-exact", "disabled")
|
||||
H, W = new_h, new_w
|
||||
scene_size = (W, H)
|
||||
|
||||
alpha_u8 = (mask[0, 0].clamp(0, 1) * 255.0).to(torch.uint8)
|
||||
fg_pixels = (alpha_u8 > 204).nonzero()
|
||||
if fg_pixels.numel() > 0:
|
||||
y_min, x_min = fg_pixels.min(dim=0).values.tolist()
|
||||
y_max, x_max = fg_pixels.max(dim=0).values.tolist()
|
||||
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
||||
# Upstream pads the bbox by 10% — encoders were trained with that breathing room.
|
||||
size = max(y_max - y_min, x_max - x_min)
|
||||
size = int(size * 1.1)
|
||||
size = int(max(y_max - y_min, x_max - x_min) * 1.1)
|
||||
half = size // 2
|
||||
crop_x1 = int(center_x - half)
|
||||
crop_y1 = int(center_y - half)
|
||||
crop_x2 = crop_x1 + 2 * half
|
||||
crop_y2 = crop_y1 + 2 * half
|
||||
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
|
||||
rgba_pil = Image.fromarray(rgba_np)
|
||||
cropped_rgba = rgba_pil.crop(crop_bbox)
|
||||
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
||||
else:
|
||||
logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.")
|
||||
cropped_np = rgba_np.astype(np.float32) / 255.0
|
||||
crop_bbox = (0, 0, scene_size[0], scene_size[1])
|
||||
fg = cropped_np[:, :, :3]
|
||||
alpha_float = cropped_np[:, :, 3:4]
|
||||
composite_np = fg * alpha_float
|
||||
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||
return Image.fromarray(composite_uint8), crop_bbox, scene_size
|
||||
crop_x1, crop_y1, crop_x2, crop_y2 = 0, 0, W, H
|
||||
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
|
||||
|
||||
# Zero-pad out-of-bounds slice (PIL.crop semantics).
|
||||
pad_l = max(0, -crop_x1)
|
||||
pad_t = max(0, -crop_y1)
|
||||
pad_r = max(0, crop_x2 - W)
|
||||
pad_b = max(0, crop_y2 - H)
|
||||
if pad_l or pad_t or pad_r or pad_b:
|
||||
img = torch.nn.functional.pad(img, (pad_l, pad_r, pad_t, pad_b), value=0.0)
|
||||
mask = torch.nn.functional.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=0.0)
|
||||
crop_x1 += pad_l; crop_x2 += pad_l
|
||||
crop_y1 += pad_t; crop_y2 += pad_t
|
||||
cropped_img = img [..., crop_y1:crop_y2, crop_x1:crop_x2]
|
||||
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
|
||||
|
||||
composite = (cropped_img * cropped_mask).clamp(0, 1)
|
||||
composite = (composite * 255.0).round().clamp(0, 255).to(torch.uint8).float() / 255.0
|
||||
return composite, crop_bbox, scene_size
|
||||
|
||||
def _fov_from_moge_intrinsics(moge_intrinsics: torch.Tensor) -> float:
    """Derive a single field-of-view angle (radians) from an intrinsics tensor.

    Reads the ``[..., 0, 0]`` element as ``fx`` and computes
    ``2 * atan(0.5 / fx)``, averaged over all leading dimensions.
    NOTE(review): the 0.5 numerator suggests fx is normalized by image width - confirm.

    Args:
        moge_intrinsics: Tensor with at least two trailing matrix axes.

    Returns:
        The mean field of view as a Python float.
    """
    fx = moge_intrinsics[..., 0, 0].float()
    # Clamp guards against division by zero for degenerate focal lengths.
    fov = 2.0 * torch.atan(0.5 / fx.clamp(min=1e-4))
    return float(fov.mean().item())
|
||||
|
||||
class Pixal3DConditioning(IO.ComfyNode):
|
||||
|
||||
@ -901,10 +876,6 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
"mesh_scale", default=1.0, min=0.1, max=4.0, step=0.01,
|
||||
tooltip="Mesh scale; 1.0 means unit cube.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"distance_override", default=0.0, min=0.0, max=10.0, step=0.001,
|
||||
tooltip="Override camera distance directly. 0 = auto-derive from FOV.",
|
||||
),
|
||||
io.Custom("MOGE_GEOMETRY").Input(
|
||||
"moge_geometry",
|
||||
optional=True,
|
||||
@ -920,13 +891,11 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
Pixal3DProjPack.Output(display_name="proj_feat_pack"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_model, image, mask, camera_angle_x, mesh_scale,
|
||||
distance_override=0.0,
|
||||
moge_geometry=None, naf_model=None) -> IO.NodeOutput:
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
@ -945,21 +914,21 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
|
||||
cond_512_list, cond_1024_list = [], []
|
||||
patches_512_list, patches_1024_list = [], []
|
||||
cropped_pil_list = []
|
||||
composite_list = []
|
||||
crop_bbox_list, scene_size_list = [], []
|
||||
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
for b in range(batch_size):
|
||||
item_image = image[b]
|
||||
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
|
||||
cropped_pil, crop_bbox, scene_size = _crop_image_with_mask(
|
||||
composite, crop_bbox, scene_size = _crop_image_with_mask(
|
||||
item_image, item_mask, max_image_size=1024)
|
||||
crop_bbox_list.append(crop_bbox)
|
||||
scene_size_list.append(scene_size)
|
||||
cropped_pil_list.append(cropped_pil)
|
||||
composite_list.append(composite)
|
||||
|
||||
cond_512 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 512)
|
||||
cond_1024 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 1024)
|
||||
cond_512 = _run_dinov3_with_patches(clip_vision_model, composite, 512)
|
||||
cond_1024 = _run_dinov3_with_patches(clip_vision_model, composite, 1024)
|
||||
cond_512_list.append(cond_512["tokens"].to(device))
|
||||
cond_1024_list.append(cond_1024["tokens"].to(device))
|
||||
patches_512_list.append(cond_512["patches_2d"].to(device))
|
||||
@ -971,42 +940,28 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
fm_512_dino = torch.cat(patches_512_list, dim=0)
|
||||
fm_1024_dino = torch.cat(patches_1024_list, dim=0)
|
||||
|
||||
# Upstream samples the LR DINO grid AND the NAF HR grid separately at projected
|
||||
# 3D points, then cats sampled features along channels. Back-projection (in model.py)
|
||||
# mirrors that — here we just stash LR + optional HR per stage.
|
||||
# The LR DINO grid AND the NAF HR grid are sampled separately
|
||||
# NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024.
|
||||
def _naf_hr(lr_feat, image_pil_list, image_size, naf_target):
|
||||
def _naf_hr(lr_feat, composites, image_size, naf_target):
|
||||
if naf_model is None or naf_target is None:
|
||||
return None
|
||||
# Run NAF in the input feature dtype (typically fp16 since DINO/ClipVision
|
||||
# loads that way). The previous .float() cast doubled NAF memory by forcing
|
||||
# full fp32 — at tex_1024/target=1024 that's ~10 GB on its own. Model
|
||||
# weights need to match input dtype since PyTorch conv ops error out on
|
||||
# mixed fp16-input/fp32-weight.
|
||||
target_dtype = lr_feat.dtype
|
||||
if next(naf_model.parameters()).dtype != target_dtype:
|
||||
naf_model.to(dtype=target_dtype)
|
||||
imgs = torch.stack([
|
||||
torch.from_numpy(
|
||||
np.array(p.resize((image_size, image_size), Image.Resampling.LANCZOS))
|
||||
.astype(np.float32) / 255.0
|
||||
).permute(2, 0, 1)
|
||||
for p in image_pil_list
|
||||
imgs = torch.cat([
|
||||
comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")
|
||||
for c in composites
|
||||
], dim=0).to(torch_device).to(target_dtype)
|
||||
|
||||
hr = naf_model(imgs, lr_feat.to(torch_device).to(target_dtype), naf_target)
|
||||
return hr.to(device)
|
||||
|
||||
hr_shape_512 = _naf_hr(fm_512_dino, cropped_pil_list, 512, (512, 512))
|
||||
hr_shape_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (512, 512))
|
||||
hr_tex_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (1024, 1024))
|
||||
hr_shape_512 = _naf_hr(fm_512_dino, composite_list, 512, (512, 512))
|
||||
hr_shape_1024 = _naf_hr(fm_1024_dino, composite_list, 1024, (512, 512))
|
||||
hr_tex_1024 = _naf_hr(fm_1024_dino, composite_list, 1024, (1024, 1024))
|
||||
|
||||
# distance_from_fov: grid_point (-1, 0, 0) projects to pixel (0, image_resolution-1).
|
||||
camera_angle_x = float(camera_angle_x)
|
||||
if distance_override > 0:
|
||||
distance = float(distance_override)
|
||||
else:
|
||||
distance = 0.5 / math.tan(camera_angle_x / 2.0) / float(mesh_scale)
|
||||
distance = 0.5 / math.tan(camera_angle_x / 2.0) / float(mesh_scale)
|
||||
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
|
||||
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
|
||||
scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32)
|
||||
@ -1028,13 +983,33 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
"scene_sizes": scene_size_list,
|
||||
}
|
||||
|
||||
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024
|
||||
# (Trellis2.forward swaps context↔embeds for non-structure HR stages).
|
||||
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
|
||||
# proj_feat_pack rides in the conditioning dict (same place embeds, ControlNet
|
||||
# hints etc. live); the sampler auto-promotes it to a model.forward kwarg via
|
||||
# Trellis2.extra_conds. The same pack object is shared between pos/neg —
|
||||
# CONDConstant.can_concat sees them equal and concats to a single dict, then
|
||||
# Trellis2.forward zeros proj for the uncond slots via cond_or_uncond.
|
||||
# Pre-compute the SS-stage proj features (dense 16³ grid) once here — the
|
||||
# shape/texture stages do their own computes in their respective stage nodes.
|
||||
# proj_pack lives on intermediate (CPU); force the compute onto cuda so
|
||||
# the bilinear-sampling step doesn't run on CPU.
|
||||
ss_proj_feats = compute_stage_proj_feats(
|
||||
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
|
||||
device=torch_device,
|
||||
)
|
||||
neg_global = torch.zeros_like(global_512)
|
||||
neg_embeds = torch.zeros_like(global_1024)
|
||||
positive = [[global_512, {"embeds": global_1024}]]
|
||||
negative = [[neg_global, {"embeds": neg_embeds}]]
|
||||
return IO.NodeOutput(positive, negative, proj_pack)
|
||||
base_extras = {
|
||||
"embeds": global_1024, "proj_feat_pack": proj_pack,
|
||||
"trellis2_proj_feats": ss_proj_feats,
|
||||
}
|
||||
neg_extras = {
|
||||
"embeds": neg_embeds, "proj_feat_pack": proj_pack,
|
||||
"trellis2_proj_feats": ss_proj_feats,
|
||||
}
|
||||
positive = [[global_512, base_extras]]
|
||||
negative = [[neg_global, neg_extras]]
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
|
||||
def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution):
|
||||
@ -1069,7 +1044,7 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
Pixal3DProjPack.Input("proj_feat_pack", tooltip="The proj pack produced by Pixal3DConditioning for this object."),
|
||||
IO.Conditioning.Input("positive", tooltip="The positive conditioning from Pixal3DConditioning for this object — Pixal3DAlignObject reads transform_matrix / camera_angle_x / mesh_scale / crop_bboxes out of its proj_feat_pack."),
|
||||
io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."),
|
||||
IO.Mask.Input(
|
||||
"object_mask",
|
||||
@ -1089,7 +1064,10 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, proj_feat_pack, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput:
|
||||
def execute(cls, mesh, positive, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput:
|
||||
proj_feat_pack = _proj_pack_from_conditioning(positive)
|
||||
if proj_feat_pack is None:
|
||||
raise ValueError("Pixal3DAlignObject: positive conditioning has no proj_feat_pack — connect a Pixal3DConditioning output.")
|
||||
vertices = mesh.vertices
|
||||
faces = mesh.faces
|
||||
if vertices.ndim == 3:
|
||||
@ -1117,10 +1095,11 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
||||
"image. Run MoGe on the same resized scene image Pixal3DConditioning used."
|
||||
)
|
||||
|
||||
# Compose VaeDecodeShapeTrellis's R_y(180°) inverse with R_proj to map user mesh
|
||||
# space to ProjGrid world: (X, Y, Z) -> (-X, Z, Y).
|
||||
# Vertices come out of VaeDecodeShapeTrellis in the Pixal3D model frame
|
||||
# (no un-rotation). Apply _PROJ_GRID_ROTATION = R_x(-90°) to map model
|
||||
# frame → ProjGrid world: (X, Y, Z) -> (X, -Z, Y).
|
||||
v = vertices_one.float()
|
||||
verts_world = torch.stack([-v[..., 0], v[..., 2], v[..., 1]], dim=-1)
|
||||
verts_world = torch.stack([v[..., 0], -v[..., 2], v[..., 1]], dim=-1)
|
||||
verts_world = verts_world / float(mesh_scale.item())
|
||||
uv_crop, _depth, valid = _project_vertices_to_image_uv(
|
||||
verts_world, T[0], cam_angle[0], image_resolution)
|
||||
@ -1200,6 +1179,74 @@ class LoadNAFModel(IO.ComfyNode):
|
||||
return IO.NodeOutput(model)
|
||||
|
||||
|
||||
class CFGGuidanceInterval(IO.ComfyNode):
    """Generic model patch: apply CFG only during [start_percent, end_percent] of
    the sampling schedule. Outside that window, skip the uncond computation and
    collapse to effective cfg=1 — same idea as upstream Trellis2 / Pixal3D's
    guidance_interval_mixin, but lives at the sampler level (via
    sampler_calc_cond_batch_function) so it works for any model.

    Percents use ComfyUI's standard convention: 0.0 = start of sampling
    (max-noise step), 1.0 = end of sampling (clean step). Conversion to sigma
    is done via model_sampling.percent_to_sigma so the window is portable
    across schedules (flow / EDM / discrete) and shift settings.

    Defaults are full-range (no bypass). For Trellis2's upstream behavior,
    wire (start_percent=0.0, end_percent=0.667) on the SS / shape KSamplers;
    texture defaults to cfg=1 so the node is moot there."""

    @classmethod
    def define_schema(cls):
        """Declare the node: one model input, two percent floats, one model output."""
        return IO.Schema(
            node_id="CFGGuidanceInterval",
            category="model_patches/sampling",
            inputs=[
                IO.Model.Input("model"),
                IO.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001,
                               tooltip="Fraction of sampling at which CFG turns ON (0 = beginning)."),
                IO.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001,
                               tooltip="Fraction of sampling at which CFG turns OFF (1 = end)."),
            ],
            outputs=[IO.Model.Output()],
        )

    @classmethod
    def execute(cls, model, start_percent, end_percent):
        """Return a clone of ``model`` whose cond-batch hook gates CFG by sigma.

        The hook is installed under the ``sampler_calc_cond_batch_function``
        key of the clone's ``model_options``; the input ``model`` is not
        modified.
        """
        # Imported locally, as in the original block.
        import comfy.samplers

        model_sampling = model.get_model_object("model_sampling")
        # percent_to_sigma is monotonically decreasing: percent=0 -> sigma_max,
        # percent=1 -> sigma_min. So start_percent < end_percent in user space
        # means sigma_start > sigma_end. "Inside the window" is sigma in
        # [sigma_end, sigma_start].
        # NOTE(review): if start_percent > end_percent the window is empty and
        # CFG is bypassed on every step — confirm that this is intended.
        sigma_start = float(model_sampling.percent_to_sigma(start_percent))
        sigma_end = float(model_sampling.percent_to_sigma(end_percent))

        def calc_cond_batch_with_interval(args):
            # Only the first sigma of the batch decides whether we are inside
            # the window.
            sigma_val = args["sigma"][0].item()
            conds = args["conds"]
            input_x = args["input"]
            timestep = args["sigma"]
            model_ref = args["model"]
            model_opts = args["model_options"]

            # conds is typically [cond, uncond]; uncond may be None when ComfyUI's
            # global cfg=1 optimization has already pruned it.
            cond = conds[0]
            uncond = conds[1] if len(conds) > 1 else None
            inside = sigma_end <= sigma_val <= sigma_start

            if uncond is None or inside:
                return comfy.samplers.calc_cond_batch(model_ref, conds, input_x, timestep, model_opts)
            # Outside the window: compute cond only, mirror it into the uncond slot
            # so the downstream cfg_function collapses to `cond` (effective cfg=1).
            out = comfy.samplers.calc_cond_batch(model_ref, [cond], input_x, timestep, model_opts)
            return [out[0], out[0]]

        m = model.clone()
        m.model_options["sampler_calc_cond_batch_function"] = calc_cond_batch_with_interval
        return IO.NodeOutput(m)
|
||||
|
||||
|
||||
class Trellis2Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1208,13 +1255,14 @@ class Trellis2Extension(ComfyExtension):
|
||||
Pixal3DConditioning,
|
||||
Pixal3DAlignObject,
|
||||
LoadNAFModel,
|
||||
EmptyTrellis2ShapeLatent,
|
||||
Trellis2ShapeStage,
|
||||
EmptyTrellis2LatentStructure,
|
||||
EmptyTrellis2LatentTexture,
|
||||
Trellis2TextureStage,
|
||||
VaeDecodeTextureTrellis,
|
||||
VaeDecodeShapeTrellis,
|
||||
VaeDecodeStructureTrellis2,
|
||||
Trellis2UpsampleCascade,
|
||||
Trellis2UpsampleStage,
|
||||
CFGGuidanceInterval,
|
||||
]
|
||||
|
||||
|
||||
|
||||
4
nodes.py
4
nodes.py
@ -1537,10 +1537,6 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
if "model_options" in latent:
|
||||
inner = model.model.diffusion_model
|
||||
inner.meta = latent["model_options"]
|
||||
|
||||
callback = latent_preview.prepare_callback(model, steps)
|
||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user