mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Cleanup
This commit is contained in:
parent
7bd1fa6e78
commit
288f3cf134
@ -770,9 +770,10 @@ class Hunyuan3Dv2_1(LatentFormat):
|
|||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
latent_dimensions = 1
|
latent_dimensions = 1
|
||||||
|
|
||||||
class Trellis2(LatentFormat): # TODO
|
class Trellis2(LatentFormat):
|
||||||
latent_channels = 32
|
latent_channels = 32
|
||||||
trellis3d_preview = True # routes the sampler preview to Trellis3DPreviewer
|
|
||||||
|
|
||||||
class Hunyuan3Dv2mini(LatentFormat):
|
class Hunyuan3Dv2mini(LatentFormat):
|
||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
latent_dimensions = 1
|
latent_dimensions = 1
|
||||||
|
|||||||
@ -32,11 +32,6 @@ except ImportError as e:
|
|||||||
raise e
|
raise e
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
try:
|
|
||||||
from sageattention import sageattn_varlen # sageattention >= 2
|
|
||||||
except ImportError:
|
|
||||||
sageattn_varlen = None
|
|
||||||
|
|
||||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||||
try:
|
try:
|
||||||
from sageattn3 import sageattn3_blackwell
|
from sageattn3 import sageattn3_blackwell
|
||||||
@ -53,24 +48,6 @@ except ImportError:
|
|||||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.nn.attention.varlen import varlen_attn as _torch_varlen_attn
|
|
||||||
except ImportError:
|
|
||||||
_torch_varlen_attn = None
|
|
||||||
|
|
||||||
|
|
||||||
def _is_varlen(kwargs):
|
|
||||||
"""Varlen mode is opted into by passing cu_seqlens_q in kwargs."""
|
|
||||||
return kwargs.get("cu_seqlens_q") is not None
|
|
||||||
|
|
||||||
|
|
||||||
def _varlen_args(kwargs):
|
|
||||||
cu_seqlens_q = kwargs["cu_seqlens_q"]
|
|
||||||
cu_seqlens_kv = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
|
|
||||||
max_seqlen_q = int(kwargs["max_seqlen_q"])
|
|
||||||
max_seqlen_kv = int(kwargs.get("max_seqlen_kv", max_seqlen_q))
|
|
||||||
return cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
|
|
||||||
|
|
||||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||||
def register_attention_function(name: str, func: Callable):
|
def register_attention_function(name: str, func: Callable):
|
||||||
# avoid replacing existing functions
|
# avoid replacing existing functions
|
||||||
@ -167,8 +144,6 @@ def wrap_attn(func):
|
|||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
if _is_varlen(kwargs):
|
|
||||||
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
|
|
||||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@ -243,8 +218,6 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
if _is_varlen(kwargs):
|
|
||||||
return attention_pytorch(query, key, value, heads, mask=mask, **kwargs)
|
|
||||||
attn_precision = get_attn_precision(attn_precision, query.dtype)
|
attn_precision = get_attn_precision(attn_precision, query.dtype)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@ -320,8 +293,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
|||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
if _is_varlen(kwargs):
|
|
||||||
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
|
|
||||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
@ -453,17 +424,6 @@ except:
|
|||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
if _is_varlen(kwargs):
|
|
||||||
# q, k, v expected as packed [T_total, H, C]. xformers wants per-item
|
|
||||||
# seqlen lists and a leading batch dim of 1.
|
|
||||||
cu_seqlens_q, cu_seqlens_kv, _max_q, _max_kv = _varlen_args(kwargs)
|
|
||||||
q_seqlen = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()
|
|
||||||
kv_seqlen = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).tolist()
|
|
||||||
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
|
|
||||||
return xformers.ops.memory_efficient_attention(
|
|
||||||
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_bias=attn_bias,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
b = q.shape[0]
|
b = q.shape[0]
|
||||||
dim_head = q.shape[-1]
|
dim_head = q.shape[-1]
|
||||||
# check to make sure xformers isn't broken
|
# check to make sure xformers isn't broken
|
||||||
@ -533,22 +493,6 @@ else:
|
|||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
if _is_varlen(kwargs):
|
|
||||||
# q, k, v expected as packed [T_total, H, C]. cu_seqlens_q / cu_seqlens_kv
|
|
||||||
# describe per-item offsets. Native varlen kernel if available, else
|
|
||||||
# nested-tensor SDPA fallback. mask/attn_precision/etc are ignored here.
|
|
||||||
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
|
|
||||||
if _torch_varlen_attn is not None:
|
|
||||||
return _torch_varlen_attn(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
|
|
||||||
q_nj = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
|
||||||
k_nj = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_kv.long())
|
|
||||||
v_nj = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_kv.long())
|
|
||||||
out = comfy.ops.scaled_dot_product_attention(
|
|
||||||
q_nj.transpose(1, 2), k_nj.transpose(1, 2), v_nj.transpose(1, 2),
|
|
||||||
attn_mask=None, dropout_p=0.0, is_causal=False,
|
|
||||||
)
|
|
||||||
return out.transpose(1, 2).values()
|
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
else:
|
else:
|
||||||
@ -597,13 +541,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
if _is_varlen(kwargs):
|
|
||||||
# q, k, v expected as packed [T_total, H, C].
|
|
||||||
if sageattn_varlen is None:
|
|
||||||
# sageattention v1 has no varlen kernel; fall back to attention_pytorch's varlen path.
|
|
||||||
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
|
|
||||||
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
|
|
||||||
return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
|
|
||||||
if kwargs.get("low_precision_attention", True) is False:
|
if kwargs.get("low_precision_attention", True) is False:
|
||||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||||
|
|
||||||
@ -761,12 +698,6 @@ except AttributeError as error:
|
|||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
if _is_varlen(kwargs):
|
|
||||||
# q, k, v expected as packed [T_total, H, C].
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
|
|
||||||
return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
|
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,33 +1,7 @@
|
|||||||
import math
|
|
||||||
from typing import Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
from comfy.ldm.trellis2.vae import VarLenTensor
|
from comfy.ldm.trellis2.vae import VarLenTensor
|
||||||
|
|
||||||
|
|
||||||
def _build_cu_seqlens(seqlen, device):
|
|
||||||
"""Cumulative-offset tensor from a list/tensor of per-item sequence lengths."""
|
|
||||||
if isinstance(seqlen, torch.Tensor):
|
|
||||||
return torch.cat(
|
|
||||||
[torch.zeros(1, dtype=torch.int32, device=seqlen.device), torch.cumsum(seqlen, dim=0).int()],
|
|
||||||
dim=0,
|
|
||||||
).to(device)
|
|
||||||
return torch.cat(
|
|
||||||
[torch.tensor([0]), torch.cumsum(torch.tensor(seqlen), dim=0)],
|
|
||||||
).int().to(device)
|
|
||||||
|
|
||||||
|
|
||||||
def _layout_to_feats(t):
|
|
||||||
"""Returns (sparse_template_or_None, seqlen, feats[T, H, C])."""
|
|
||||||
if isinstance(t, VarLenTensor):
|
|
||||||
seqlen = [t.layout[i].stop - t.layout[i].start for i in range(t.shape[0])]
|
|
||||||
return t, seqlen, t.feats
|
|
||||||
N, L = t.shape[:2]
|
|
||||||
return None, [L] * N, t.reshape(-1, *t.shape[-2:])
|
|
||||||
|
|
||||||
|
|
||||||
def dense_attention(q, k, v, **kwargs):
|
def dense_attention(q, k, v, **kwargs):
|
||||||
"""q, k, v: [B, L, H, C]. Permutes for comfy's [B, H, L, C] convention."""
|
"""q, k, v: [B, L, H, C]. Permutes for comfy's [B, H, L, C] convention."""
|
||||||
heads = q.shape[2]
|
heads = q.shape[2]
|
||||||
@ -38,86 +12,36 @@ def dense_attention(q, k, v, **kwargs):
|
|||||||
return out.permute(0, 2, 1, 3)
|
return out.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_rect(t):
|
||||||
|
"""Fold a VarLenTensor packed as [sum(L_i), H, C] into a dense [B, L, H, C].
|
||||||
|
|
||||||
|
The sparse generation stages run a single object per call (optionally
|
||||||
|
CFG-duplicated, which keeps every batch entry the same length), so the
|
||||||
|
packed layout is rectangular and attention is ordinary dense attention over
|
||||||
|
a batch dim — no variable-length kernel needed. A dense [B, L, H, C] tensor
|
||||||
|
(e.g. cross-attention context) passes through unchanged.
|
||||||
|
"""
|
||||||
|
if not isinstance(t, VarLenTensor):
|
||||||
|
return t
|
||||||
|
B = t.shape[0]
|
||||||
|
seqlens = [t.layout[i].stop - t.layout[i].start for i in range(B)]
|
||||||
|
if len(set(seqlens)) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"trellis2 sparse attention expects equal sequence lengths per batch "
|
||||||
|
f"(single object, optionally CFG-duplicated); got {seqlens}. "
|
||||||
|
"Multi-object batching is not supported."
|
||||||
|
)
|
||||||
|
return t.feats.view(B, seqlens[0], *t.feats.shape[1:])
|
||||||
|
|
||||||
|
|
||||||
def sparse_attention(q, k, v, **kwargs):
|
def sparse_attention(q, k, v, **kwargs):
|
||||||
|
"""Full attention over a SparseTensor's voxels.
|
||||||
|
|
||||||
|
Single object (optionally CFG-duplicated) => the packed layout is
|
||||||
|
rectangular, so we fold it into a batch dim and run ordinary dense
|
||||||
|
attention. Output type matches q.
|
||||||
"""
|
"""
|
||||||
Varlen attention for SparseTensor inputs. Each of q, k, v may be a VarLenTensor
|
out = dense_attention(_to_rect(q), _to_rect(k), _to_rect(v), **kwargs) # [B, Lq, H, C]
|
||||||
(sparse) or dense [B, L, H, C]. Output type matches q. Backend dispatch lives
|
if isinstance(q, VarLenTensor):
|
||||||
in comfy.ldm.modules.attention.optimized_attention; we just build cu_seqlens
|
return q.replace(out.reshape(-1, *out.shape[2:]))
|
||||||
from the layouts.
|
return out
|
||||||
"""
|
|
||||||
s, q_seqlen, q_feats = _layout_to_feats(q)
|
|
||||||
_, kv_seqlen, k_feats = _layout_to_feats(k)
|
|
||||||
_, _, v_feats = _layout_to_feats(v)
|
|
||||||
heads = q_feats.shape[1]
|
|
||||||
|
|
||||||
device = q_feats.device
|
|
||||||
cu_seqlens_q = _build_cu_seqlens(q_seqlen, device)
|
|
||||||
cu_seqlens_kv = _build_cu_seqlens(kv_seqlen, device)
|
|
||||||
|
|
||||||
out = optimized_attention(
|
|
||||||
q_feats, k_feats, v_feats, heads,
|
|
||||||
cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv,
|
|
||||||
max_seqlen_q=max(q_seqlen), max_seqlen_kv=max(kv_seqlen),
|
|
||||||
skip_reshape=True, skip_output_reshape=True,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if s is not None:
|
|
||||||
return s.replace(out)
|
|
||||||
N, L = q.shape[:2]
|
|
||||||
return out.reshape(N, L, heads, -1)
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_windowed_self_attention(qkv, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0)):
|
|
||||||
"""Windowed sparse self-attention. Partitions voxels into windows via spatial
|
|
||||||
sort, then runs varlen attention with one sequence per non-empty window."""
|
|
||||||
cache_name = f'windowed_attention_{window_size}_{shift_window}'
|
|
||||||
cache = qkv.get_spatial_cache(cache_name)
|
|
||||||
if cache is None:
|
|
||||||
cache = calc_window_partition(qkv, window_size, shift_window)
|
|
||||||
qkv.register_spatial_cache(cache_name, cache)
|
|
||||||
fwd_indices, bwd_indices, seq_lens = cache
|
|
||||||
|
|
||||||
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
|
|
||||||
q, k, v = qkv_feats.unbind(dim=1) # each [M, H, C]
|
|
||||||
heads = q.shape[1]
|
|
||||||
device = q.device
|
|
||||||
|
|
||||||
cu_seqlens = _build_cu_seqlens(seq_lens, device)
|
|
||||||
max_seqlen = int(seq_lens.max())
|
|
||||||
|
|
||||||
out = optimized_attention(
|
|
||||||
q, k, v, heads,
|
|
||||||
cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
|
|
||||||
max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen,
|
|
||||||
skip_reshape=True, skip_output_reshape=True,
|
|
||||||
)
|
|
||||||
out = out[bwd_indices]
|
|
||||||
return qkv.replace(out)
|
|
||||||
|
|
||||||
|
|
||||||
def calc_window_partition(
|
|
||||||
tensor,
|
|
||||||
window_size: Union[int, Tuple[int, ...]],
|
|
||||||
shift_window: Union[int, Tuple[int, ...]] = 0,
|
|
||||||
):
|
|
||||||
"""Returns (fwd_indices, bwd_indices, seq_lens) for window partitioning."""
|
|
||||||
DIM = tensor.coords.shape[1] - 1
|
|
||||||
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
|
|
||||||
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
|
|
||||||
shifted_coords = tensor.coords.clone().detach()
|
|
||||||
shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
|
||||||
|
|
||||||
MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
|
|
||||||
NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
|
|
||||||
OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
|
|
||||||
|
|
||||||
shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
|
||||||
shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
|
|
||||||
fwd_indices = torch.argsort(shifted_indices)
|
|
||||||
bwd_indices = torch.empty_like(fwd_indices)
|
|
||||||
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
|
|
||||||
seq_lens = torch.bincount(shifted_indices)
|
|
||||||
seq_lens = seq_lens[seq_lens != 0]
|
|
||||||
|
|
||||||
return fwd_indices, bwd_indices, seq_lens
|
|
||||||
|
|||||||
@ -3,8 +3,6 @@ import torch
|
|||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
UINT32_SENTINEL = 0xFFFFFFFF
|
|
||||||
|
|
||||||
|
|
||||||
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
|
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
|
||||||
"""Kernel spatial offsets in the same order as the CUDA/Triton kernels."""
|
"""Kernel spatial offsets in the same order as the CUDA/Triton kernels."""
|
||||||
@ -19,11 +17,9 @@ def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
|
|||||||
class TorchHashMap:
|
class TorchHashMap:
|
||||||
"""Sorted-array hashmap backed by torch.searchsorted."""
|
"""Sorted-array hashmap backed by torch.searchsorted."""
|
||||||
|
|
||||||
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
|
def __init__(self, keys: torch.Tensor, values: torch.Tensor):
|
||||||
device = keys.device
|
|
||||||
self.sorted_keys, order = torch.sort(keys.to(torch.long))
|
self.sorted_keys, order = torch.sort(keys.to(torch.long))
|
||||||
self.sorted_vals = values.to(torch.long)[order]
|
self.sorted_vals = values.to(torch.long)[order]
|
||||||
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
|
|
||||||
self._n = self.sorted_keys.numel()
|
self._n = self.sorted_keys.numel()
|
||||||
|
|
||||||
# Chunk size for lookup_flat, caps each transient to ~CHUNK rows.
|
# Chunk size for lookup_flat, caps each transient to ~CHUNK rows.
|
||||||
@ -54,65 +50,28 @@ def build_submanifold_neighbor_map(
|
|||||||
Kw, Kh, Kd,
|
Kw, Kh, Kd,
|
||||||
Dw, Dh, Dd,
|
Dw, Dh, Dd,
|
||||||
):
|
):
|
||||||
|
# neighbor[i, v] = index of the voxel at voxel i's coord + kernel-offset v, or -1.
|
||||||
|
# Chunked over voxels so the [chunk, V, 3] candidate transient stays bounded.
|
||||||
device = coords.device
|
device = coords.device
|
||||||
M = coords.shape[0]
|
M = coords.shape[0]
|
||||||
V = Kw * Kh * Kd
|
offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device).long() # [V, 3]
|
||||||
half_V = V // 2 + 1
|
V = offsets.shape[0]
|
||||||
INVALID = -1
|
center = torch.tensor([(Kw // 2) * Dw, (Kh // 2) * Dh, (Kd // 2) * Dd], device=device)
|
||||||
|
WHD, HD = W * H * D, H * D
|
||||||
|
|
||||||
# int32 neighbour map: 4 bytes/elem vs 8 bytes for int64
|
neighbor = torch.empty((M, V), dtype=torch.int32, device=device)
|
||||||
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.int32)
|
# ~V*40 bytes/voxel of transient (int64 cand + flat + masks); cap at ~0.5 GB.
|
||||||
|
chunk = max(1, min(M, int(0.5 * (1024 ** 3) / (V * 40))))
|
||||||
b = coords[:, 0].long()
|
|
||||||
x = coords[:, 1].long()
|
|
||||||
y = coords[:, 2].long()
|
|
||||||
z = coords[:, 3].long()
|
|
||||||
|
|
||||||
offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
|
|
||||||
|
|
||||||
ox = x - (Kw // 2) * Dw
|
|
||||||
oy = y - (Kh // 2) * Dh
|
|
||||||
oz = z - (Kd // 2) * Dd
|
|
||||||
|
|
||||||
for v in range(half_V):
|
|
||||||
if v == half_V - 1:
|
|
||||||
# Center voxel always maps to itself
|
|
||||||
neighbor[:, v] = torch.arange(M, device=device, dtype=torch.int32)
|
|
||||||
continue
|
|
||||||
|
|
||||||
dx, dy, dz = offsets[v]
|
|
||||||
|
|
||||||
kx = ox + dx
|
|
||||||
ky = oy + dy
|
|
||||||
kz = oz + dz
|
|
||||||
|
|
||||||
valid = (
|
|
||||||
(kx >= 0) & (kx < W) &
|
|
||||||
(ky >= 0) & (ky < H) &
|
|
||||||
(kz >= 0) & (kz < D)
|
|
||||||
)
|
|
||||||
|
|
||||||
flat = (
|
|
||||||
b[valid] * (W * H * D) +
|
|
||||||
kx[valid] * (H * D) +
|
|
||||||
ky[valid] * D +
|
|
||||||
kz[valid]
|
|
||||||
)
|
|
||||||
|
|
||||||
if flat.numel() > 0:
|
|
||||||
found = hashmap.lookup_flat(flat)
|
|
||||||
idx_in_M = torch.where(valid)[0]
|
|
||||||
neighbor[idx_in_M, v] = found.to(torch.int32)
|
|
||||||
|
|
||||||
# BUG FIX: old code used found != hashmap.default_value which
|
|
||||||
# compared int32 -1 against int64 4294967295 → always True.
|
|
||||||
# We now explicitly check for valid indices.
|
|
||||||
valid_found_mask = found >= 0
|
|
||||||
if valid_found_mask.any():
|
|
||||||
src_points = idx_in_M[valid_found_mask]
|
|
||||||
dst_points = found[valid_found_mask].long()
|
|
||||||
neighbor[dst_points, V - 1 - v] = src_points.to(torch.int32)
|
|
||||||
|
|
||||||
|
for s in range(0, M, chunk):
|
||||||
|
e = min(s + chunk, M)
|
||||||
|
b = coords[s:e, 0].long()
|
||||||
|
cand = coords[s:e, 1:4].long()[:, None, :] + offsets[None, :, :] - center # [c, V, 3]
|
||||||
|
x, y, z = cand[..., 0], cand[..., 1], cand[..., 2]
|
||||||
|
in_bounds = (x >= 0) & (x < W) & (y >= 0) & (y < H) & (z >= 0) & (z < D) # [c, V]
|
||||||
|
flat = b[:, None] * WHD + x * HD + y * D + z # [c, V]
|
||||||
|
flat = torch.where(in_bounds, flat, torch.full_like(flat, -1)) # OOB -> guaranteed miss
|
||||||
|
neighbor[s:e] = hashmap.lookup_flat(flat.reshape(-1)).view(e - s, V)
|
||||||
return neighbor
|
return neighbor
|
||||||
|
|
||||||
def get_recommended_chunk_mem(
|
def get_recommended_chunk_mem(
|
||||||
@ -155,7 +114,7 @@ def sparse_submanifold_conv3d(
|
|||||||
coords[:, 2].long() * y_stride +
|
coords[:, 2].long() * y_stride +
|
||||||
coords[:, 3].long() * z_stride)
|
coords[:, 3].long() * z_stride)
|
||||||
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
|
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
|
||||||
hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL)
|
hashmap = TorchHashMap(flat_keys, vals)
|
||||||
|
|
||||||
neighbor = build_submanifold_neighbor_map(
|
neighbor = build_submanifold_neighbor_map(
|
||||||
hashmap, coords, W, H, D, Kw, Kh, Kd,
|
hashmap, coords, W, H, D, Kw, Kh, Kd,
|
||||||
@ -165,13 +124,15 @@ def sparse_submanifold_conv3d(
|
|||||||
neighbor = neighbor_cache
|
neighbor = neighbor_cache
|
||||||
|
|
||||||
N_pts = feats.shape[0]
|
N_pts = feats.shape[0]
|
||||||
sentinel = -1
|
|
||||||
|
|
||||||
weight_T = weight.view(Co, V * Ci).T
|
weight_T = weight.view(Co, V * Ci).T
|
||||||
|
|
||||||
output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)
|
output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)
|
||||||
|
|
||||||
# Chunk size from memory budget. The dominant peak is `gathered`, of shape (chunk, V, Ci) in feats.dtype.
|
# Zero row at index N_pts; missing neighbors (-1) gather it -> no separate masking.
|
||||||
|
feats_padded = torch.cat([feats, feats.new_zeros(1, Ci)], dim=0)
|
||||||
|
|
||||||
|
# Chunk over voxels to bound the (chunk, V, Ci) gather.
|
||||||
max_chunk_mem_gb = get_recommended_chunk_mem(device)
|
max_chunk_mem_gb = get_recommended_chunk_mem(device)
|
||||||
mem_per_row = V * Ci * feats.element_size()
|
mem_per_row = V * Ci * feats.element_size()
|
||||||
max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
|
max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
|
||||||
@ -182,18 +143,10 @@ def sparse_submanifold_conv3d(
|
|||||||
end = min(start + chunk_size, N_pts)
|
end = min(start + chunk_size, N_pts)
|
||||||
actual_chunk = end - start
|
actual_chunk = end - start
|
||||||
|
|
||||||
chunk_neighbor = neighbor[start:end]
|
chunk_idx = torch.where(neighbor[start:end] < 0, N_pts, neighbor[start:end]) # -1 -> zero row
|
||||||
chunk_valid = chunk_neighbor != sentinel
|
gathered = feats_padded[chunk_idx] # (chunk, V, Ci)
|
||||||
# clamp(-1 -> 0) keeps invalid indices in-range so the gather is safe
|
|
||||||
chunk_idx = chunk_neighbor.clamp(min=0)
|
|
||||||
|
|
||||||
# (chunk, V, Ci) gather, then in-place zero of invalid neighbors.
|
|
||||||
gathered = feats[chunk_idx]
|
|
||||||
gathered.mul_(chunk_valid.unsqueeze(-1))
|
|
||||||
|
|
||||||
# GEMM (chunk, V*Ci) @ (V*Ci, Co) -> (chunk, Co), written to output[start:end].
|
|
||||||
gathered_flat = gathered.view(actual_chunk, V * Ci)
|
gathered_flat = gathered.view(actual_chunk, V * Ci)
|
||||||
torch.matmul(gathered_flat, weight_T, out=output[start:end])
|
torch.matmul(gathered_flat, weight_T, out=output[start:end]) # (chunk, V*Ci) @ (V*Ci, Co)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
output += bias.unsqueeze(0).to(output.dtype)
|
output += bias.unsqueeze(0).to(output.dtype)
|
||||||
|
|||||||
@ -3,12 +3,9 @@ import torch.nn.functional as F
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
|
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
|
||||||
from typing import Optional, Tuple, Literal, Union, List
|
from typing import Optional, Tuple, Literal, Union, List
|
||||||
from comfy.ldm.trellis2.attention import (
|
from comfy.ldm.trellis2.attention import sparse_attention, dense_attention
|
||||||
sparse_windowed_self_attention, sparse_attention, dense_attention
|
|
||||||
)
|
|
||||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||||
from comfy.ldm.trellis2 import sampling_preview
|
|
||||||
|
|
||||||
class SparseGELU(nn.GELU):
|
class SparseGELU(nn.GELU):
|
||||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||||
@ -26,13 +23,6 @@ class SparseFeedForwardNet(nn.Module):
|
|||||||
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
||||||
return self.mlp(x)
|
return self.mlp(x)
|
||||||
|
|
||||||
class LayerNorm32(nn.LayerNorm):
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
x_dtype = x.dtype
|
|
||||||
x = x.to(dtype=torch.float32)
|
|
||||||
o = super().forward(x)
|
|
||||||
return o.to(dtype=x_dtype)
|
|
||||||
|
|
||||||
class SparseMultiHeadRMSNorm(nn.Module):
|
class SparseMultiHeadRMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, heads: int, device, dtype):
|
def __init__(self, dim: int, heads: int, device, dtype):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -44,13 +34,7 @@ class SparseMultiHeadRMSNorm(nn.Module):
|
|||||||
return F.rms_norm(x, (x.shape[-1],)) * self.gamma
|
return F.rms_norm(x, (x.shape[-1],)) * self.gamma
|
||||||
|
|
||||||
class SparseRotaryPositionEmbedder(nn.Module):
|
class SparseRotaryPositionEmbedder(nn.Module):
|
||||||
def __init__(
|
def __init__(self, head_dim: int, dim: int = 3, rope_freq: Tuple[float, float] = (1.0, 10000.0), device=None):
|
||||||
self,
|
|
||||||
head_dim: int,
|
|
||||||
dim: int = 3,
|
|
||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
|
||||||
device=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@ -111,12 +95,7 @@ class SparseMultiHeadAttention(nn.Module):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
ctx_channels: Optional[int] = None,
|
ctx_channels: Optional[int] = None,
|
||||||
type: Literal["self", "cross"] = "self",
|
type: Literal["self", "cross"] = "self",
|
||||||
attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
|
|
||||||
window_size: Optional[int] = None,
|
|
||||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
use_rope: bool = False,
|
|
||||||
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
|
||||||
qk_rms_norm: bool = False,
|
qk_rms_norm: bool = False,
|
||||||
device=None, dtype=None, operations=None
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
@ -127,10 +106,6 @@ class SparseMultiHeadAttention(nn.Module):
|
|||||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self._type = type
|
self._type = type
|
||||||
self.attn_mode = attn_mode
|
|
||||||
self.window_size = window_size
|
|
||||||
self.shift_window = shift_window
|
|
||||||
self.use_rope = use_rope
|
|
||||||
self.qk_rms_norm = qk_rms_norm
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
|
||||||
if self._type == "self":
|
if self._type == "self":
|
||||||
@ -145,8 +120,8 @@ class SparseMultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||||
|
|
||||||
if use_rope:
|
if self._type == "self":
|
||||||
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device)
|
self.rope = SparseRotaryPositionEmbedder(self.head_dim, device=device)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||||
@ -170,43 +145,16 @@ class SparseMultiHeadAttention(nn.Module):
|
|||||||
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
|
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
|
||||||
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
|
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
|
||||||
|
|
||||||
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None,
|
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None, transformer_options=None) -> SparseTensor:
|
||||||
transformer_options=None) -> SparseTensor:
|
|
||||||
if self._type == "self":
|
if self._type == "self":
|
||||||
qkv = self._linear(self.to_qkv, x)
|
qkv = self._linear(self.to_qkv, x)
|
||||||
qkv = self._fused_pre(qkv, num_fused=3)
|
qkv = self._fused_pre(qkv, num_fused=3)
|
||||||
if self.attn_mode == "full":
|
q, k, v = qkv.unbind(dim=-3)
|
||||||
q, k, v = qkv.unbind(dim=-3)
|
if self.qk_rms_norm:
|
||||||
if self.qk_rms_norm:
|
q = self.q_rms_norm(q)
|
||||||
q = self.q_rms_norm(q)
|
k = self.k_rms_norm(k)
|
||||||
k = self.k_rms_norm(k)
|
q, k = self.rope(q, k)
|
||||||
if self.use_rope:
|
h = sparse_attention(q, k, v, transformer_options=transformer_options)
|
||||||
q, k = self.rope(q, k)
|
|
||||||
h = sparse_attention(q, k, v, transformer_options=transformer_options)
|
|
||||||
else:
|
|
||||||
# Windowed paths take packed qkv; preserve any per-head norm/rope.
|
|
||||||
if self.qk_rms_norm or self.use_rope:
|
|
||||||
q, k, v = qkv.unbind(dim=-3)
|
|
||||||
if self.qk_rms_norm:
|
|
||||||
q = self.q_rms_norm(q)
|
|
||||||
k = self.k_rms_norm(k)
|
|
||||||
if self.use_rope:
|
|
||||||
q, k = self.rope(q, k)
|
|
||||||
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
|
|
||||||
if self.attn_mode == "windowed":
|
|
||||||
h = sparse_windowed_self_attention(
|
|
||||||
qkv, self.window_size, shift_window=self.shift_window
|
|
||||||
)
|
|
||||||
elif self.attn_mode == "double_windowed":
|
|
||||||
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
|
|
||||||
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
|
|
||||||
h0 = sparse_windowed_self_attention(
|
|
||||||
qkv0, self.window_size, shift_window=(0, 0, 0)
|
|
||||||
)
|
|
||||||
h1 = sparse_windowed_self_attention(
|
|
||||||
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
|
|
||||||
)
|
|
||||||
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
|
|
||||||
else:
|
else:
|
||||||
q = self._linear(self.to_q, x)
|
q = self._linear(self.to_q, x)
|
||||||
q = self._reshape_chs(q, (self.num_heads, -1))
|
q = self._reshape_chs(q, (self.num_heads, -1))
|
||||||
@ -276,37 +224,25 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
|||||||
ctx_channels: int,
|
ctx_channels: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
mlp_ratio: float = 4.0,
|
mlp_ratio: float = 4.0,
|
||||||
attn_mode: Literal["full", "swin"] = "full",
|
|
||||||
window_size: Optional[int] = None,
|
|
||||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
|
||||||
use_checkpoint: bool = False,
|
|
||||||
use_rope: bool = False,
|
|
||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
|
||||||
qk_rms_norm: bool = False,
|
qk_rms_norm: bool = False,
|
||||||
qk_rms_norm_cross: bool = False,
|
qk_rms_norm_cross: bool = False,
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
share_mod: bool = False,
|
share_mod: bool = False,
|
||||||
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
|
image_attn_mode: Literal["global", "proj"] = "global",
|
||||||
proj_in_channels: Optional[int] = None,
|
proj_in_channels: Optional[int] = None,
|
||||||
device=None, dtype=None, operations=None
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_checkpoint = use_checkpoint
|
|
||||||
self.share_mod = share_mod
|
self.share_mod = share_mod
|
||||||
self.image_attn_mode = image_attn_mode
|
self.image_attn_mode = image_attn_mode
|
||||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
|
||||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||||
self.self_attn = SparseMultiHeadAttention(
|
self.self_attn = SparseMultiHeadAttention(
|
||||||
channels,
|
channels,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
type="self",
|
type="self",
|
||||||
attn_mode=attn_mode,
|
|
||||||
window_size=window_size,
|
|
||||||
shift_window=shift_window,
|
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
use_rope=use_rope,
|
|
||||||
rope_freq=rope_freq,
|
|
||||||
qk_rms_norm=qk_rms_norm,
|
qk_rms_norm=qk_rms_norm,
|
||||||
device=device, dtype=dtype, operations=operations
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
@ -315,7 +251,6 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
|||||||
ctx_channels=ctx_channels,
|
ctx_channels=ctx_channels,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
type="cross",
|
type="cross",
|
||||||
attn_mode="full",
|
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
qk_rms_norm=qk_rms_norm_cross,
|
qk_rms_norm=qk_rms_norm_cross,
|
||||||
device=device, dtype=dtype, operations=operations
|
device=device, dtype=dtype, operations=operations
|
||||||
@ -347,8 +282,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
|||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||||
else:
|
else:
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||||
# Fuse the (mul + add) and (mul + residual) pairs into addcmul so the
|
# Fuse the (mul + add) and (mul + residual) pairs into addcmul
|
||||||
# mod/shift broadcasts hit one kernel each instead of two.
|
|
||||||
b_map = x.batch_boardcast_map
|
b_map = x.batch_boardcast_map
|
||||||
|
|
||||||
h_feats = self.norm1(x.feats)
|
h_feats = self.norm1(x.feats)
|
||||||
@ -386,18 +320,12 @@ class SLatFlowModel(nn.Module):
|
|||||||
num_heads: Optional[int] = None,
|
num_heads: Optional[int] = None,
|
||||||
num_head_channels: Optional[int] = 64,
|
num_head_channels: Optional[int] = 64,
|
||||||
mlp_ratio: float = 4,
|
mlp_ratio: float = 4,
|
||||||
pe_mode: Literal["ape", "rope"] = "rope",
|
|
||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
|
||||||
use_checkpoint: bool = False,
|
|
||||||
share_mod: bool = False,
|
share_mod: bool = False,
|
||||||
initialization: str = 'vanilla',
|
|
||||||
qk_rms_norm: bool = False,
|
qk_rms_norm: bool = False,
|
||||||
qk_rms_norm_cross: bool = False,
|
qk_rms_norm_cross: bool = False,
|
||||||
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
|
image_attn_mode: Literal["global", "proj"] = "global",
|
||||||
proj_in_channels: Optional[int] = None,
|
proj_in_channels: Optional[int] = None,
|
||||||
dtype = None,
|
dtype = None, device = None, operations = None,
|
||||||
device = None,
|
|
||||||
operations = None,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.resolution = resolution
|
self.resolution = resolution
|
||||||
@ -408,10 +336,7 @@ class SLatFlowModel(nn.Module):
|
|||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.num_heads = num_heads or model_channels // num_head_channels
|
self.num_heads = num_heads or model_channels // num_head_channels
|
||||||
self.mlp_ratio = mlp_ratio
|
self.mlp_ratio = mlp_ratio
|
||||||
self.pe_mode = pe_mode
|
|
||||||
self.use_checkpoint = use_checkpoint
|
|
||||||
self.share_mod = share_mod
|
self.share_mod = share_mod
|
||||||
self.initialization = initialization
|
|
||||||
self.qk_rms_norm = qk_rms_norm
|
self.qk_rms_norm = qk_rms_norm
|
||||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||||
self.image_attn_mode = image_attn_mode
|
self.image_attn_mode = image_attn_mode
|
||||||
@ -433,10 +358,6 @@ class SLatFlowModel(nn.Module):
|
|||||||
cond_channels,
|
cond_channels,
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
attn_mode='full',
|
|
||||||
use_checkpoint=self.use_checkpoint,
|
|
||||||
use_rope=(pe_mode == "rope"),
|
|
||||||
rope_freq=rope_freq,
|
|
||||||
share_mod=self.share_mod,
|
share_mod=self.share_mod,
|
||||||
qk_rms_norm=self.qk_rms_norm,
|
qk_rms_norm=self.qk_rms_norm,
|
||||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||||
@ -491,14 +412,22 @@ class FeedForwardNet(nn.Module):
|
|||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.mlp(x)
|
return self.mlp(x)
|
||||||
|
|
||||||
|
# class MultiHeadRMSNorm(nn.Module):
|
||||||
|
# def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
||||||
|
# super().__init__()
|
||||||
|
# self.scale = dim ** 0.5
|
||||||
|
# self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
# def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
||||||
|
|
||||||
class MultiHeadRMSNorm(nn.Module):
|
class MultiHeadRMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = dim ** 0.5
|
|
||||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
return (F.rms_norm(x.float(), (x.shape[-1],)) * self.gamma).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
@ -508,12 +437,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
ctx_channels: Optional[int]=None,
|
ctx_channels: Optional[int]=None,
|
||||||
type: Literal["self", "cross"] = "self",
|
type: Literal["self", "cross"] = "self",
|
||||||
attn_mode: Literal["full", "windowed"] = "full",
|
|
||||||
window_size: Optional[int] = None,
|
|
||||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
use_rope: bool = False,
|
|
||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
|
||||||
qk_rms_norm: bool = False,
|
qk_rms_norm: bool = False,
|
||||||
device=None, dtype=None, operations=None
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
@ -524,10 +448,6 @@ class MultiHeadAttention(nn.Module):
|
|||||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self._type = type
|
self._type = type
|
||||||
self.attn_mode = attn_mode
|
|
||||||
self.window_size = window_size
|
|
||||||
self.shift_window = shift_window
|
|
||||||
self.use_rope = use_rope
|
|
||||||
self.qk_rms_norm = qk_rms_norm
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
|
||||||
if self._type == "self":
|
if self._type == "self":
|
||||||
@ -552,12 +472,11 @@ class MultiHeadAttention(nn.Module):
|
|||||||
if self.qk_rms_norm:
|
if self.qk_rms_norm:
|
||||||
q = self.q_rms_norm(q)
|
q = self.q_rms_norm(q)
|
||||||
k = self.k_rms_norm(k)
|
k = self.k_rms_norm(k)
|
||||||
if self.use_rope:
|
assert phases is not None, "Phases must be provided for RoPE"
|
||||||
assert phases is not None, "Phases must be provided for RoPE"
|
# phases is [L, head_dim/2, 2, 2]; broadcast to [1, L, 1, ...]
|
||||||
# phases is [L, head_dim/2, 2, 2]; broadcast to [1, L, 1, ...]
|
# to align with q/k of shape [B, L, H, head_dim].
|
||||||
# to align with q/k of shape [B, L, H, head_dim].
|
f_cis = phases.unsqueeze(0).unsqueeze(2)
|
||||||
f_cis = phases.unsqueeze(0).unsqueeze(2)
|
q, k = apply_rope(q, k, f_cis)
|
||||||
q, k = apply_rope(q, k, f_cis)
|
|
||||||
h = dense_attention(q, k, v, transformer_options=transformer_options)
|
h = dense_attention(q, k, v, transformer_options=transformer_options)
|
||||||
else:
|
else:
|
||||||
Lkv = context.shape[1]
|
Lkv = context.shape[1]
|
||||||
@ -581,37 +500,25 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
|||||||
ctx_channels: int,
|
ctx_channels: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
mlp_ratio: float = 4.0,
|
mlp_ratio: float = 4.0,
|
||||||
attn_mode: Literal["full", "windowed"] = "full",
|
|
||||||
window_size: Optional[int] = None,
|
|
||||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
|
||||||
use_checkpoint: bool = False,
|
|
||||||
use_rope: bool = False,
|
|
||||||
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
|
||||||
qk_rms_norm: bool = False,
|
qk_rms_norm: bool = False,
|
||||||
qk_rms_norm_cross: bool = False,
|
qk_rms_norm_cross: bool = False,
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
share_mod: bool = False,
|
share_mod: bool = False,
|
||||||
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
|
image_attn_mode: Literal["global", "proj"] = "global",
|
||||||
proj_in_channels: Optional[int] = None,
|
proj_in_channels: Optional[int] = None,
|
||||||
device=None, dtype=None, operations=None
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_checkpoint = use_checkpoint
|
|
||||||
self.share_mod = share_mod
|
self.share_mod = share_mod
|
||||||
self.image_attn_mode = image_attn_mode
|
self.image_attn_mode = image_attn_mode
|
||||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
|
||||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||||
self.self_attn = MultiHeadAttention(
|
self.self_attn = MultiHeadAttention(
|
||||||
channels,
|
channels,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
type="self",
|
type="self",
|
||||||
attn_mode=attn_mode,
|
|
||||||
window_size=window_size,
|
|
||||||
shift_window=shift_window,
|
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
use_rope=use_rope,
|
|
||||||
rope_freq=rope_freq,
|
|
||||||
qk_rms_norm=qk_rms_norm,
|
qk_rms_norm=qk_rms_norm,
|
||||||
device=device, dtype=dtype, operations=operations
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
@ -620,7 +527,6 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
|||||||
ctx_channels=ctx_channels,
|
ctx_channels=ctx_channels,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
type="cross",
|
type="cross",
|
||||||
attn_mode="full",
|
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
qk_rms_norm=qk_rms_norm_cross,
|
qk_rms_norm=qk_rms_norm_cross,
|
||||||
device=device, dtype=dtype, operations=operations
|
device=device, dtype=dtype, operations=operations
|
||||||
@ -640,25 +546,17 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
|||||||
device=device, dtype=dtype, operations=operations
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
if not share_mod:
|
if not share_mod:
|
||||||
self.adaLN_modulation = nn.Sequential(
|
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||||
nn.SiLU(),
|
|
||||||
operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||||
|
|
||||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context,
|
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context,
|
||||||
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
|
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
|
||||||
if self.share_mod:
|
if self.share_mod:
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
mod = (self.modulation + mod).type(mod.dtype)
|
||||||
else:
|
else:
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
mod = self.adaLN_modulation(mod)
|
||||||
shift_msa = shift_msa.unsqueeze(1)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(6, dim=-1)
|
||||||
scale_msa = scale_msa.unsqueeze(1)
|
|
||||||
gate_msa = gate_msa.unsqueeze(1)
|
|
||||||
shift_mlp = shift_mlp.unsqueeze(1)
|
|
||||||
scale_mlp = scale_mlp.unsqueeze(1)
|
|
||||||
gate_mlp = gate_mlp.unsqueeze(1)
|
|
||||||
|
|
||||||
h = torch.addcmul(shift_msa, self.norm1(x), 1 + scale_msa)
|
h = torch.addcmul(shift_msa, self.norm1(x), 1 + scale_msa)
|
||||||
h = self.self_attn(h, phases=phases, transformer_options=transformer_options)
|
h = self.self_attn(h, phases=phases, transformer_options=transformer_options)
|
||||||
@ -694,14 +592,10 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
num_heads: Optional[int] = None,
|
num_heads: Optional[int] = None,
|
||||||
num_head_channels: Optional[int] = 64,
|
num_head_channels: Optional[int] = 64,
|
||||||
mlp_ratio: float = 4,
|
mlp_ratio: float = 4,
|
||||||
pe_mode: Literal["ape", "rope"] = "rope",
|
|
||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
|
||||||
use_checkpoint: bool = False,
|
|
||||||
share_mod: bool = False,
|
share_mod: bool = False,
|
||||||
initialization: str = 'vanilla',
|
|
||||||
qk_rms_norm: bool = False,
|
qk_rms_norm: bool = False,
|
||||||
qk_rms_norm_cross: bool = False,
|
qk_rms_norm_cross: bool = False,
|
||||||
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
|
image_attn_mode: Literal["global", "proj"] = "global",
|
||||||
proj_in_channels: Optional[int] = None,
|
proj_in_channels: Optional[int] = None,
|
||||||
operations=None,
|
operations=None,
|
||||||
device = None,
|
device = None,
|
||||||
@ -718,10 +612,7 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
self.num_heads = num_heads or model_channels // num_head_channels
|
self.num_heads = num_heads or model_channels // num_head_channels
|
||||||
self.mlp_ratio = mlp_ratio
|
self.mlp_ratio = mlp_ratio
|
||||||
self.pe_mode = pe_mode
|
|
||||||
self.use_checkpoint = use_checkpoint
|
|
||||||
self.share_mod = share_mod
|
self.share_mod = share_mod
|
||||||
self.initialization = initialization
|
|
||||||
self.qk_rms_norm = qk_rms_norm
|
self.qk_rms_norm = qk_rms_norm
|
||||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||||
self.image_attn_mode = image_attn_mode
|
self.image_attn_mode = image_attn_mode
|
||||||
@ -742,9 +633,6 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
rope_phases = pos_embedder(coords)
|
rope_phases = pos_embedder(coords)
|
||||||
self.register_buffer("rope_phases", rope_phases, persistent=False)
|
self.register_buffer("rope_phases", rope_phases, persistent=False)
|
||||||
|
|
||||||
if pe_mode != "rope":
|
|
||||||
self.rope_phases = None
|
|
||||||
|
|
||||||
self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype)
|
self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
@ -753,10 +641,6 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
cond_channels,
|
cond_channels,
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
attn_mode='full',
|
|
||||||
use_checkpoint=self.use_checkpoint,
|
|
||||||
use_rope=(pe_mode == "rope"),
|
|
||||||
rope_freq=rope_freq,
|
|
||||||
share_mod=share_mod,
|
share_mod=share_mod,
|
||||||
qk_rms_norm=self.qk_rms_norm,
|
qk_rms_norm=self.qk_rms_norm,
|
||||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||||
@ -788,18 +672,10 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
|
|
||||||
return h
|
return h
|
||||||
|
|
||||||
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
|
||||||
t_shifted = t_shifted / 1000.0
|
|
||||||
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
|
|
||||||
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
|
|
||||||
t_new *= 1000.0
|
|
||||||
return t_new
|
|
||||||
|
|
||||||
|
# Pixal3D ProjGrid math
|
||||||
# Pixal3D ProjGrid math — port of upstream's ProjGrid + project_points_to_image_batch.
|
# World frame uses world Y as depth, camera looks along -Z local
|
||||||
# World frame uses world Y as depth, camera looks along -Z local;
|
# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2) with sensor_width = 32mm.
|
||||||
# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2)
|
|
||||||
# with sensor_width = 32mm.
|
|
||||||
|
|
||||||
_PROJ_GRID_ROTATION = torch.tensor(
|
_PROJ_GRID_ROTATION = torch.tensor(
|
||||||
[[1.0, 0.0, 0.0],
|
[[1.0, 0.0, 0.0],
|
||||||
@ -815,7 +691,7 @@ _PROJ_FRONT_VIEW_TRANSFORM = torch.tensor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_proj_transform_matrix(distance: torch.Tensor, batch_size: int,
|
def build_proj_transform_matrix(distance: torch.Tensor, batch_size: int,
|
||||||
device, dtype=torch.float32) -> torch.Tensor:
|
device, dtype=torch.float32) -> torch.Tensor:
|
||||||
T = _PROJ_FRONT_VIEW_TRANSFORM.to(device=device, dtype=dtype)
|
T = _PROJ_FRONT_VIEW_TRANSFORM.to(device=device, dtype=dtype)
|
||||||
T = T.unsqueeze(0).expand(batch_size, -1, -1).clone()
|
T = T.unsqueeze(0).expand(batch_size, -1, -1).clone()
|
||||||
@ -849,8 +725,7 @@ def _project_points_to_image(points_world: torch.Tensor, transform_matrix: torch
|
|||||||
def _sample_features(feature_map: torch.Tensor, uv_ndc: torch.Tensor) -> torch.Tensor:
|
def _sample_features(feature_map: torch.Tensor, uv_ndc: torch.Tensor) -> torch.Tensor:
|
||||||
B, C, _, _ = feature_map.shape
|
B, C, _, _ = feature_map.shape
|
||||||
grid = uv_ndc.view(B, -1, 1, 2).to(feature_map.dtype)
|
grid = uv_ndc.view(B, -1, 1, 2).to(feature_map.dtype)
|
||||||
feat = F.grid_sample(feature_map, grid, mode="bilinear",
|
feat = F.grid_sample(feature_map, grid, mode="bilinear", padding_mode="border", align_corners=False)
|
||||||
padding_mode="border", align_corners=False)
|
|
||||||
return feat.squeeze(-1)
|
return feat.squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
@ -922,8 +797,6 @@ def _back_project_to_tokens(
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _select_stage_entry(proj_pack: dict, stage: Optional[str]):
|
def _select_stage_entry(proj_pack: dict, stage: Optional[str]):
|
||||||
"""Returns (feature_map_lr, feature_map_hr_or_None, image_resolution)."""
|
"""Returns (feature_map_lr, feature_map_hr_or_None, image_resolution)."""
|
||||||
stages = proj_pack.get("stages")
|
stages = proj_pack.get("stages")
|
||||||
@ -1025,7 +898,7 @@ def _shape_proj_cond(global_cond: torch.Tensor, image_attn_mode: str,
|
|||||||
else:
|
else:
|
||||||
proj_feats = proj_feats.repeat((repeats, 1))
|
proj_feats = proj_feats.repeat((repeats, 1))
|
||||||
|
|
||||||
# Mirror upstream's neg_cond by zeroing proj for any uncond batch slot.
|
# zero proj for any uncond batch slot
|
||||||
if cond_or_uncond is not None and eval_batch is not None:
|
if cond_or_uncond is not None and eval_batch is not None:
|
||||||
uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1]
|
uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1]
|
||||||
if uncond_slots:
|
if uncond_slots:
|
||||||
@ -1100,9 +973,6 @@ class Trellis2(nn.Module):
|
|||||||
# Pre-computed per-stage back-projected features
|
# Pre-computed per-stage back-projected features
|
||||||
proj_feats = kwargs.get("trellis2_proj_feats")
|
proj_feats = kwargs.get("trellis2_proj_feats")
|
||||||
|
|
||||||
sampling_preview.set_context(mode=mode, coords=coords, coord_counts=coord_counts,
|
|
||||||
model_frame=kwargs.get("trellis2_model_frame"))
|
|
||||||
|
|
||||||
is_first_shape_pass = False
|
is_first_shape_pass = False
|
||||||
if mode == "shape_generation_512":
|
if mode == "shape_generation_512":
|
||||||
is_first_shape_pass = True
|
is_first_shape_pass = True
|
||||||
|
|||||||
@ -75,153 +75,10 @@ class Latent2RGBPreviewer(LatentPreviewer):
|
|||||||
return preview_to_image(latent_image)
|
return preview_to_image(latent_image)
|
||||||
|
|
||||||
|
|
||||||
class Trellis3DPreviewer(LatentPreviewer):
|
|
||||||
"""Per-step preview for the Trellis2/Pixal3D cascade.
|
|
||||||
|
|
||||||
Structure stage: x0 is a dense [B, 32, 16, 16, 16] grid — project the per-cell
|
|
||||||
activation norm orthographically to a 2D occupancy heatmap (no decode, no coords).
|
|
||||||
Texture stage: x0 is sparse [B, 32, N, 1] — splat the first 3 latent channels as
|
|
||||||
pseudo-color onto the fixed voxel coords (read from the sampling side-channel).
|
|
||||||
Shape stage adds no visible motion (coords are fixed, only sub-voxel detail
|
|
||||||
evolves) and a full decode per step is too costly, so it's skipped.
|
|
||||||
|
|
||||||
Both stages render through one orthographic point splatter (static view).
|
|
||||||
"""
|
|
||||||
_SIZE = 128
|
|
||||||
_FILL = 0.9 # fraction of frame the texture splat spans (leaves a border)
|
|
||||||
_STRUCTURE_ZOOM = 0.66 # <1 pulls the SS camera back, leaving margin around the blob
|
|
||||||
|
|
||||||
def _splat(self, points, colors, rad):
|
|
||||||
# points: [K, 3] voxel-index coords. colors: [K, 3] in [0, 1].
|
|
||||||
# Center + isotropic-normalize, project orthographically front-on
|
|
||||||
# (x->horizontal, y->up, z->depth), then splat a square footprint per point
|
|
||||||
# with one global far->near sort (painter's).
|
|
||||||
S = self._SIZE
|
|
||||||
dev = points.device # keep every tensor here
|
|
||||||
p = points.float()
|
|
||||||
p = p - (p.amax(0) + p.amin(0)) * 0.5
|
|
||||||
p = p / p.abs().amax().clamp(min=1e-8)
|
|
||||||
x, y, z = p[:, 0], p[:, 1], p[:, 2]
|
|
||||||
depth = z # into-screen
|
|
||||||
m = self._FILL
|
|
||||||
u = ((x * m * 0.5 + 0.5) * (S - 1)).long().clamp(0, S - 1)
|
|
||||||
v = (((-y) * m * 0.5 + 0.5) * (S - 1)).long().clamp(0, S - 1) # image up = +y
|
|
||||||
cols = colors.to(dev)
|
|
||||||
us, vs, ds, cs = [], [], [], []
|
|
||||||
for dv in range(-rad, rad + 1):
|
|
||||||
for du in range(-rad, rad + 1):
|
|
||||||
us.append((u + du).clamp(0, S - 1))
|
|
||||||
vs.append((v + dv).clamp(0, S - 1))
|
|
||||||
ds.append(depth)
|
|
||||||
cs.append(cols)
|
|
||||||
order = torch.cat(ds).argsort()
|
|
||||||
img = torch.zeros(S, S, 3, device=dev)
|
|
||||||
img[torch.cat(vs)[order], torch.cat(us)[order]] = torch.cat(cs)[order]
|
|
||||||
return preview_to_image(img, do_scale=False)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _turbo(x):
|
|
||||||
# Anton Mikhailov polynomial approximation of the turbo colormap. x: any shape
|
|
||||||
# in [0, 1] -> (..., 3) RGB.
|
|
||||||
x = x.clamp(0.0, 1.0)
|
|
||||||
x2 = x * x; x3 = x2 * x; x4 = x2 * x2; x5 = x4 * x
|
|
||||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
|
||||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
|
||||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
|
||||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
|
||||||
|
|
||||||
def _structure(self, x0):
|
|
||||||
# x0: [B, 32, D, H, W]; the model only consumes the first 8 channels.
|
|
||||||
# Dense orthographic max-projection -> filled occupancy heatmap (turbo-colored,
|
|
||||||
# intensity-weighted so empty space stays black).
|
|
||||||
act = x0[0, :min(8, x0.shape[1])].float().norm(dim=0) # [D, H, W]
|
|
||||||
proj = act.amax(dim=2) # project along one axis
|
|
||||||
proj = (proj - proj.amin()) / (proj.amax() - proj.amin() + 1e-8)
|
|
||||||
inner = max(1, int(round(self._SIZE * self._STRUCTURE_ZOOM)))
|
|
||||||
img = torch.nn.functional.interpolate(proj[None, None], size=(inner, inner), mode="nearest")
|
|
||||||
pad = self._SIZE - inner
|
|
||||||
pl, pt = pad // 2, pad // 2
|
|
||||||
gray = torch.nn.functional.pad(img, (pl, pad - pl, pt, pad - pt))[0, 0] # [S, S], zero margin
|
|
||||||
rgb = self._turbo(gray) * gray.unsqueeze(-1) # [S, S, 3], black where empty
|
|
||||||
return preview_to_image(rgb, do_scale=False)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _latent_color(latent):
|
|
||||||
# Prefer the calibrated latent->base_color map (fit from real decoded
|
|
||||||
# albedo by VaeDecodeTextureTrellis); fall back to PCA pseudo-color until
|
|
||||||
# a texture decode has trained it.
|
|
||||||
try:
|
|
||||||
from comfy.ldm.trellis2 import sampling_preview
|
|
||||||
factors = sampling_preview.get_tex_rgb()
|
|
||||||
except Exception:
|
|
||||||
factors = None
|
|
||||||
if factors is not None:
|
|
||||||
W, b = factors
|
|
||||||
rgb = latent @ W.to(latent) + b.to(latent)
|
|
||||||
return rgb.clamp(0, 1)
|
|
||||||
return Trellis3DPreviewer._pca_color(latent)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _pca_color(latent):
|
|
||||||
# latent: [n, C]. Map the 3 directions of maximum variance to RGB.
|
|
||||||
# Higher contrast and more coherent than picking 3 fixed channels.
|
|
||||||
X = latent - latent.mean(dim=0, keepdim=True)
|
|
||||||
cov = (X.transpose(0, 1) @ X) / max(X.shape[0] - 1, 1) # [C, C]
|
|
||||||
_, evecs = torch.linalg.eigh(cov) # ascending eigenvalues
|
|
||||||
pcs = evecs[:, -3:] # [C, 3] top-3 components
|
|
||||||
# Deterministic sign per component (largest-magnitude entry positive) to
|
|
||||||
# stop the preview's hues from flickering as the latent rotates each step.
|
|
||||||
sign = torch.sign(pcs[pcs.abs().argmax(dim=0), torch.arange(3, device=pcs.device)])
|
|
||||||
pcs = pcs * sign.clamp(min=-1.0)
|
|
||||||
proj = X @ pcs # [n, 3]
|
|
||||||
pmin = proj.amin(dim=0, keepdim=True)
|
|
||||||
pmax = proj.amax(dim=0, keepdim=True)
|
|
||||||
return ((proj - pmin) / (pmax - pmin + 1e-8)).clamp(0, 1)
|
|
||||||
|
|
||||||
def _texture(self, x0, coords, model_frame=None):
|
|
||||||
if coords.shape[-1] == 4:
|
|
||||||
b0 = coords[:, 0] == 0
|
|
||||||
spatial = coords[b0][:, 1:4].float()
|
|
||||||
else:
|
|
||||||
spatial = coords[:, :3].float()
|
|
||||||
n0 = spatial.shape[0]
|
|
||||||
if n0 == 0:
|
|
||||||
return None
|
|
||||||
if model_frame == "z_up":
|
|
||||||
spatial = torch.stack([spatial[:, 0], spatial[:, 2], -spatial[:, 1]], dim=-1)
|
|
||||||
latent = x0[0, :, :n0, 0].float().transpose(0, 1) # [n0, C]
|
|
||||||
colors = self._latent_color(latent) # [n0, 3]
|
|
||||||
res = float(spatial.abs().max().item()) + 1.0
|
|
||||||
rad = max(1, int(round(self._SIZE * self._FILL / max(res, 1) / 2)))
|
|
||||||
return self._splat(spatial, colors, rad)
|
|
||||||
|
|
||||||
def decode_latent_to_preview(self, x0):
|
|
||||||
try:
|
|
||||||
from comfy.ldm.trellis2 import sampling_preview
|
|
||||||
ctx = sampling_preview.get_context()
|
|
||||||
if x0.ndim == 5:
|
|
||||||
return self._structure(x0)
|
|
||||||
mode = ctx.get("mode")
|
|
||||||
coords = ctx.get("coords")
|
|
||||||
if mode == "texture_generation" and coords is not None:
|
|
||||||
return self._texture(x0, coords, model_frame=ctx.get("model_frame"))
|
|
||||||
except Exception as e:
|
|
||||||
logging.debug(f"Trellis3DPreviewer: skipping preview ({e})")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def decode_latent_to_preview_image(self, preview_format, x0):
|
|
||||||
preview_image = self.decode_latent_to_preview(x0)
|
|
||||||
if preview_image is None:
|
|
||||||
return None
|
|
||||||
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
|
|
||||||
|
|
||||||
|
|
||||||
def get_previewer(device, latent_format):
|
def get_previewer(device, latent_format):
|
||||||
previewer = None
|
previewer = None
|
||||||
method = args.preview_method
|
method = args.preview_method
|
||||||
if method != LatentPreviewMethod.NoPreviews:
|
if method != LatentPreviewMethod.NoPreviews:
|
||||||
if getattr(latent_format, "trellis3d_preview", False):
|
|
||||||
return Trellis3DPreviewer()
|
|
||||||
# TODO previewer methods
|
# TODO previewer methods
|
||||||
taesd_decoder_path = None
|
taesd_decoder_path = None
|
||||||
if latent_format.taesd_decoder_name is not None:
|
if latent_format.taesd_decoder_name is not None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user