diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 11db46d94..cd8337a7e 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1955,3 +1955,110 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No transformer_options.pop("ar_state", None) return output + + +def _cube_process_logits(logits, top_p, generator): + """Token selection. top_p>=1 or <=0 -> greedy argmax (upstream default, deterministic).""" + if top_p is None or top_p >= 1.0 or top_p <= 0.0: + return torch.argmax(logits, dim=-1, keepdim=True) + sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True) + remove = sorted_logits.softmax(dim=-1).cumsum(dim=-1) > top_p + remove[..., 0] = False + idx_remove = remove.scatter(-1, sorted_idx, remove) + logits = logits.masked_fill(idx_remove, float("-inf")) + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1, generator=generator) + + +@torch.no_grad() +def sample_cube(model, x, sigmas, extra_args=None, callback=None, disable=None, top_p=1.0): + """ + Autoregressive sampler for Roblox Cube3D shape GPT (DualStreamRoformer). + + Not a diffusion sampler: the noised input `x` and `sigmas` values are ignored; + only x's shape (batch, num_tokens) is used. Generates a 1024-long sequence of VQ + token IDs from CLIP text conditioning, with upstream's linearly-decaying CFG and + optional top-p. Plugs into SamplerCustomAdvanced via the SamplerCube node. + + Faithful to cube3d.inference.engine.Engine.run_gpt: + gamma_i = cfg * (T - i) / T ; logits = (1+gamma)*cond - gamma*uncond + fp32 weights + bf16 autocast on cuda. + """ + import comfy.model_management + extra_args = {} if extra_args is None else extra_args + + guider = model.inner_model # CFGGuider + base_model = guider.inner_model # BaseModel (Cube3D) + cube = base_model.diffusion_model + cfg = getattr(guider, "cfg", 3.0) + + def get_cond(name): + conds = guider.conds.get(name, None) + if not conds: + return None + return conds[0]["model_conds"]["c_crossattn"].cond + + pos = get_cond("positive") + neg = get_cond("negative") + if pos is None: + raise ValueError("sample_cube requires positive conditioning (CLIP-L text embeds).") + + device = x.device + weight_dtype = base_model.get_dtype() + T = x.shape[1] + use_cfg = (cfg is not None) and (cfg > 0.0) and (neg is not None) + autocast_enabled = (device.type == "cuda") + cache_dtype = torch.bfloat16 if autocast_enabled else weight_dtype + + def add_bbox(c): + if not getattr(cube, "use_bbox", False): + return c + bbox = torch.zeros((c.shape[0], 3), device=device, dtype=c.dtype) + return torch.cat([c, cube.bbox_proj(bbox).unsqueeze(1)], dim=1) + + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled): + cond = add_bbox(cube.encode_text(pos.to(device=device, dtype=weight_dtype))) + if use_cfg: + ucond = add_bbox(cube.encode_text(neg.to(device=device, dtype=weight_dtype))) + cond = torch.cat([cond, ucond], dim=0) + + bos = torch.full((cond.shape[0], 1), cube.shape_bos_id, dtype=torch.long, device=device) + embed = cube.encode_token(bos) + Bp, input_seq_len, dim = embed.shape + embed_buffer = torch.zeros((Bp, input_seq_len + T, dim), dtype=embed.dtype, device=device) + embed_buffer[:, :input_seq_len, :].copy_(embed) + + kv_cache = cube.init_kv_cache(Bp, cond.shape[1], T + 1, cache_dtype, device) + + num_codes = cube.vocab_size - 3 + seed = extra_args.get("seed", 0) + generator = None + if device.type != "mps": + generator = torch.Generator(device=device).manual_seed(int(seed)) + + output_ids = [] + for i in trange(T, disable=disable): + comfy.model_management.throw_exception_if_processing_interrupted() + curr_pos_id = torch.tensor([i], dtype=torch.long, device=device) + logits = cube(embed_buffer, cond, kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=(i > 0)) + logits = logits[:, 0, :num_codes] + + if use_cfg: + cond_logits, uncond_logits = logits.float().chunk(2, dim=0) + gamma = cfg * (T - i) / T + logits = (1.0 + gamma) * cond_logits - gamma * uncond_logits + else: + logits = logits.float() + + next_id = _cube_process_logits(logits, top_p, generator) + output_ids.append(next_id) + + next_embed = cube.encode_token(next_id) + if use_cfg: + next_embed = torch.cat([next_embed, next_embed], dim=0) + embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1)) + + if callback is not None: + callback({"x": x, "i": i, "sigma": sigmas[0], "sigma_hat": sigmas[0], "denoised": x}) + + return torch.cat(output_ids, dim=1).to(torch.float32) diff --git a/comfy/ldm/cube/gpt.py b/comfy/ldm/cube/gpt.py new file mode 100644 index 000000000..421648a8d --- /dev/null +++ b/comfy/ldm/cube/gpt.py @@ -0,0 +1,417 @@ +""" +Native port of Roblox/cube's shape GPT (DualStreamRoformer). + +Reference: https://github.com/Roblox/cube (cube3d/model/gpt/dual_stream_roformer.py +and cube3d/model/transformers/*). + +This is an autoregressive transformer over discrete VQ shape tokens, conditioned on +CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated +`sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler. + +The forward pass is kept faithful to upstream so token IDs match bit-for-bit: + * rope_theta = 10000 + * per-head RMSNorm on Q and K + * dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only + * two separate RoPE frequency tensors (dual blocks offset cond tokens by S) + * SwiGLU MLP, non-affine LayerNorm upcast to fp32 +""" + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# --------------------------------------------------------------------------- +# Norms (faithful to cube3d/model/transformers/norm.py) +# --------------------------------------------------------------------------- + +class CubeLayerNorm(nn.Module): + """Non-affine LayerNorm that upcasts to fp32 then back (matches cube).""" + + def __init__(self, dim, eps=1e-6): + super().__init__() + self.dim = (dim,) + self.eps = eps + + def forward(self, x): + y = F.layer_norm(x.float(), self.dim, None, None, self.eps) + return y.type_as(x) + + +class CubeRMSNorm(nn.Module): + """Per-head RMSNorm with learnable weight, computed in fp32 (matches cube).""" + + def __init__(self, dim, eps=1e-5, dtype=None, device=None): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device)) + + def forward(self, x): + xf = x.float() + out = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) + return (out * self.weight).type_as(x) + + +# --------------------------------------------------------------------------- +# RoPE (faithful to cube3d/model/transformers/rope.py) +# --------------------------------------------------------------------------- + +def apply_rotary_emb(x, freqs_cis, curr_pos_id=None): + x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + if curr_pos_id is None: + freqs_cis = freqs_cis[:, -x.shape[2]:].unsqueeze(1) + else: + freqs_cis = freqs_cis[:, curr_pos_id, :].unsqueeze(1) + y = torch.view_as_real(x_ * freqs_cis).flatten(3) + return y.type_as(x) + + +def precompute_freqs_cis(dim, t, theta=10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim)) + freqs = torch.outer(t.contiguous().view(-1), freqs).reshape(*t.shape, -1) + return torch.polar(torch.ones_like(freqs), freqs) + + +def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False): + q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id) + k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None) + return F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=0.0, + is_causal=is_causal and attn_mask is None, + ) + + +# --------------------------------------------------------------------------- +# KV cache +# --------------------------------------------------------------------------- + +class Cache: + def __init__(self, key_states, value_states): + self.key_states = key_states + self.value_states = value_states + + def update(self, curr_pos_id, k, v): + self.key_states.index_copy_(2, curr_pos_id, k) + self.value_states.index_copy_(2, curr_pos_id, v) + + +# --------------------------------------------------------------------------- +# Shared building blocks +# --------------------------------------------------------------------------- + +class SwiGLUMLP(nn.Module): + def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.gate_proj = operations.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device) + self.up_proj = operations.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device) + self.down_proj = operations.Linear(hidden_dim, embed_dim, bias=bias, dtype=dtype, device=device) + + def forward(self, x): + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class SelfAttentionWithRotaryEmbedding(nn.Module): + def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None): + super().__init__() + assert embed_dim % num_heads == 0 + self.num_heads = num_heads + head_dim = embed_dim // num_heads + self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device) + self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device) + self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device) + + def forward(self, x, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False): + b, l, d = x.shape + q, k = self.c_qk(x).chunk(2, dim=-1) + v = self.c_v(x) + q = q.view(b, l, self.num_heads, -1).transpose(1, 2) + k = k.view(b, l, self.num_heads, -1).transpose(1, 2) + v = v.view(b, l, self.num_heads, -1).transpose(1, 2) + q = self.q_norm(q) + k = self.k_norm(k) + if kv_cache is not None: + if not decode: + kv_cache.key_states[:, :, :k.shape[2], :].copy_(k) + kv_cache.value_states[:, :, :k.shape[2], :].copy_(v) + else: + kv_cache.update(curr_pos_id, k, v) + k = kv_cache.key_states + v = kv_cache.value_states + y = sdpa_with_rope(q, k, v, freqs_cis=freqs_cis, attn_mask=attn_mask, + curr_pos_id=curr_pos_id if decode else None, is_causal=is_causal) + y = y.transpose(1, 2).contiguous().view(b, l, d) + return self.c_proj(y) + + +class DecoderLayerWithRotaryEmbedding(nn.Module): + """Single-stream decoder layer (shape tokens only).""" + + def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None): + super().__init__() + self.ln_1 = CubeLayerNorm(embed_dim, eps=eps) + self.attn = SelfAttentionWithRotaryEmbedding(embed_dim, num_heads, bias=bias, eps=eps, + dtype=dtype, device=device, operations=operations) + self.ln_2 = CubeLayerNorm(embed_dim, eps=eps) + self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device, operations=operations) + + def forward(self, x, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False): + x = x + self.attn(self.ln_1(x), freqs_cis=freqs_cis, attn_mask=attn_mask, is_causal=is_causal, + kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=decode) + x = x + self.mlp(self.ln_2(x)) + return x + + +# --------------------------------------------------------------------------- +# Dual-stream blocks (faithful to dual_stream_attention.py) +# --------------------------------------------------------------------------- + +class DismantledPreAttention(nn.Module): + def __init__(self, embed_dim, num_heads, query=True, bias=True, dtype=None, device=None, operations=None): + super().__init__() + assert embed_dim % num_heads == 0 + self.query = query + head_dim = embed_dim // num_heads + if query: + self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device) + self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device) + else: + self.c_k = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device) + self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.num_heads = num_heads + + def _to_mha(self, x): + return x.view(*x.shape[:2], self.num_heads, -1).transpose(1, 2) + + def forward(self, x): + if self.query: + q, k = self.c_qk(x).chunk(2, dim=-1) + q = self.q_norm(self._to_mha(q)) + else: + q = None + k = self.c_k(x) + k = self.k_norm(self._to_mha(k)) + v = self._to_mha(self.c_v(x)) + return (q, k, v) + + +class DismantledPostAttention(nn.Module): + def __init__(self, embed_dim, bias=True, eps=1e-6, dtype=None, device=None, operations=None): + super().__init__() + self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.ln_3 = CubeLayerNorm(embed_dim, eps=eps) + self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device, operations=operations) + + def forward(self, x, a): + x = x + self.c_proj(a) + x = x + self.mlp(self.ln_3(x)) + return x + + +class DualStreamAttentionWithRotaryEmbedding(nn.Module): + def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.cond_pre_only = cond_pre_only + self.pre_x = DismantledPreAttention(embed_dim, num_heads, query=True, bias=bias, + dtype=dtype, device=device, operations=operations) + self.pre_c = DismantledPreAttention(embed_dim, num_heads, query=not cond_pre_only, bias=bias, + dtype=dtype, device=device, operations=operations) + + def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False): + if kv_cache is None or not decode: + qkv_c = self.pre_c(c) + qkv_x = self.pre_x(x) + if self.cond_pre_only: + q = qkv_x[0] + else: + q = torch.cat([qkv_c[0], qkv_x[0]], dim=2) + k = torch.cat([qkv_c[1], qkv_x[1]], dim=2) + v = torch.cat([qkv_c[2], qkv_x[2]], dim=2) + else: + is_causal = False + q, k, v = self.pre_x(x) + + if kv_cache is not None: + if not decode: + kv_cache.key_states[:, :, :k.shape[2], :].copy_(k) + kv_cache.value_states[:, :, :k.shape[2], :].copy_(v) + else: + kv_cache.update(curr_pos_id, k, v) + k = kv_cache.key_states + v = kv_cache.value_states + + if attn_mask is not None: + if decode: + attn_mask = attn_mask[..., curr_pos_id, :] + else: + attn_mask = attn_mask[..., -q.shape[2]:, :] + + y = sdpa_with_rope(q, k, v, freqs_cis=freqs_cis, attn_mask=attn_mask, + curr_pos_id=curr_pos_id if decode else None, is_causal=is_causal) + y = y.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2]) + + if y.shape[1] == x.shape[1]: + return y, None + y_c, y_x = torch.split(y, [c.shape[1], x.shape[1]], dim=1) + return y_x, y_c + + +class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module): + def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, eps=1e-6, + dtype=None, device=None, operations=None): + super().__init__() + self.ln_1 = CubeLayerNorm(embed_dim, eps=eps) + self.ln_2 = CubeLayerNorm(embed_dim, eps=eps) + self.attn = DualStreamAttentionWithRotaryEmbedding(embed_dim, num_heads, cond_pre_only=cond_pre_only, + bias=bias, dtype=dtype, device=device, operations=operations) + self.post_1 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations) + if not cond_pre_only: + self.post_2 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations) + + def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False): + a_x, a_c = self.attn( + self.ln_1(x), + self.ln_2(c) if c is not None else None, + freqs_cis=freqs_cis, attn_mask=attn_mask, is_causal=is_causal, + kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=decode, + ) + x = self.post_1(x, a_x) + if a_c is not None: + c = self.post_2(c, a_c) + else: + c = None + return x, c + + +# --------------------------------------------------------------------------- +# DualStreamRoformer +# --------------------------------------------------------------------------- + +class DualStreamRoformer(nn.Module): + def __init__( + self, + n_layer=23, + n_single_layer=1, + rope_theta=10000, + n_head=12, + n_embd=1536, + bias=True, + eps=1e-6, + shape_model_vocab_size=16384, + shape_model_embed_dim=32, + text_model_embed_dim=768, + use_bbox=True, + image_model=None, # detection key; unused + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + self.n_layer = n_layer + self.n_single_layer = n_single_layer + self.n_head = n_head + self.n_embd = n_embd + self.rope_theta = rope_theta + self.head_dim = n_embd // n_head + + self.text_proj = operations.Linear(text_model_embed_dim, n_embd, bias=bias, dtype=dtype, device=device) + self.shape_proj = operations.Linear(shape_model_embed_dim, n_embd, bias=True, dtype=dtype, device=device) + + self.vocab_size = shape_model_vocab_size + self.shape_bos_id = self.vocab_size + self.shape_eos_id = self.vocab_size + 1 + self.padding_id = self.vocab_size + 2 + self.vocab_size += 3 + + self.transformer = nn.ModuleDict(dict( + wte=operations.Embedding(self.vocab_size, n_embd, padding_idx=self.padding_id, dtype=dtype, device=device), + dual_blocks=nn.ModuleList([ + DualStreamDecoderLayerWithRotaryEmbedding( + n_embd, n_head, cond_pre_only=(i == n_layer - 1), bias=bias, eps=eps, + dtype=dtype, device=device, operations=operations, + ) + for i in range(n_layer) + ]), + single_blocks=nn.ModuleList([ + DecoderLayerWithRotaryEmbedding(n_embd, n_head, bias=bias, eps=eps, + dtype=dtype, device=device, operations=operations) + for _ in range(n_single_layer) + ]), + ln_f=CubeLayerNorm(n_embd, eps=eps), + )) + + self.lm_head = operations.Linear(n_embd, self.vocab_size, bias=False, dtype=dtype, device=device) + + self.use_bbox = use_bbox + if use_bbox: + self.bbox_proj = operations.Linear(3, n_embd, bias=True, dtype=dtype, device=device) + + def encode_text(self, text_embed): + return self.text_proj(text_embed) + + def encode_token(self, tokens): + return self.transformer.wte(tokens) + + def init_kv_cache(self, batch_size, cond_len, max_shape_tokens, dtype, device): + max_all = cond_len + max_shape_tokens + kv = [ + Cache( + torch.zeros((batch_size, self.n_head, max_all, self.head_dim), dtype=dtype, device=device), + torch.zeros((batch_size, self.n_head, max_all, self.head_dim), dtype=dtype, device=device), + ) + for _ in range(len(self.transformer.dual_blocks)) + ] + kv += [ + Cache( + torch.zeros((batch_size, self.n_head, max_shape_tokens, self.head_dim), dtype=dtype, device=device), + torch.zeros((batch_size, self.n_head, max_shape_tokens, self.head_dim), dtype=dtype, device=device), + ) + for _ in range(len(self.transformer.single_blocks)) + ] + return kv + + def forward(self, embed, cond, kv_cache=None, curr_pos_id=None, decode=False): + b, l = embed.shape[:2] + s = cond.shape[1] + device = embed.device + + attn_mask = torch.tril(torch.ones(s + l, s + l, dtype=torch.bool, device=device)) + + position_ids = torch.arange(l, dtype=torch.long, device=device).unsqueeze(0).expand(b, -1) + s_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta) + + position_ids = torch.cat([ + torch.zeros([b, s], dtype=torch.long, device=device), + position_ids, + ], dim=1) + d_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta) + + if kv_cache is not None and decode: + embed = embed[:, curr_pos_id, :] + + h = embed + c = cond + layer_idx = 0 + for block in self.transformer.dual_blocks: + h, c = block( + h, c=c, freqs_cis=d_freqs_cis, attn_mask=attn_mask, is_causal=True, + kv_cache=kv_cache[layer_idx] if kv_cache is not None else None, + curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None, + decode=decode, + ) + layer_idx += 1 + for block in self.transformer.single_blocks: + h = block( + h, freqs_cis=s_freqs_cis, attn_mask=None, is_causal=True, + kv_cache=kv_cache[layer_idx] if kv_cache is not None else None, + curr_pos_id=curr_pos_id, decode=decode, + ) + layer_idx += 1 + + h = self.transformer.ln_f(h) + return self.lm_head(h) diff --git a/comfy/ldm/cube/vae.py b/comfy/ldm/cube/vae.py new file mode 100644 index 000000000..001741dac --- /dev/null +++ b/comfy/ldm/cube/vae.py @@ -0,0 +1,345 @@ +""" +Native port of Roblox/cube's shape tokenizer decode path (OneDAutoEncoder). + +Reference: https://github.com/Roblox/cube (cube3d/model/autoencoder/*). + +Only the DECODE path is ported (token IDs -> latents -> occupancy grid -> mesh); +the point-cloud encoder is not needed for text-to-3D generation. Encoder weights in +the checkpoint are loaded with strict=False and ignored. + +Module/parameter names mirror upstream so the checkpoint loads directly: + embedder.weight + bottleneck.block.{codebook, cb_weight, cb_bias, c_in, c_x, c_out, ...} + decoder.{positional_encodings, blocks.N...} + occupancy_decoder.{query_in, attn_out, ln_f, c_head} +""" + +import logging +import math +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +ops = comfy.ops.disable_weight_init + + +# --------------------------------------------------------------------------- +# Norms +# --------------------------------------------------------------------------- + +class CubeLayerNorm(nn.Module): + """LayerNorm upcasting to fp32. affine=False by default (no params).""" + + def __init__(self, dim, eps=1e-6, elementwise_affine=False, dtype=None, device=None): + super().__init__() + self.dim = (dim,) + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device)) + self.bias = nn.Parameter(torch.zeros(dim, dtype=dtype, device=device)) + else: + self.weight = None + self.bias = None + + def forward(self, x): + w = self.weight.float() if self.weight is not None else None + b = self.bias.float() if self.bias is not None else None + y = F.layer_norm(x.float(), self.dim, w, b, self.eps) + return y.type_as(x) + + +class CubeRMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5, elementwise_affine=True, dtype=None, device=None): + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device)) + else: + self.register_buffer("weight", torch.ones(dim), persistent=False) + + def forward(self, x): + xf = x.float() + out = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) + return (out * self.weight.float()).type_as(x) + + +# --------------------------------------------------------------------------- +# Fourier embedder +# --------------------------------------------------------------------------- + +class PhaseModulatedFourierEmbedder(nn.Module): + def __init__(self, num_freqs, input_dim=3, dtype=None, device=None): + super().__init__() + self.weight = nn.Parameter(torch.empty(input_dim, num_freqs, dtype=dtype, device=device)) + carrier = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs) + carrier = (carrier + torch.linspace(0, 1, num_freqs)) * 2 * math.pi + self.register_buffer("carrier", carrier, persistent=False) + self.out_dim = input_dim * (num_freqs * 2 + 1) + + def forward(self, x): + m = x.float().unsqueeze(-1) + w = self.weight.float() + carrier = self.carrier.float() + fm = (m * w).view(*x.shape[:-1], -1) + pm = (m * 0.5 * math.pi + carrier).view(*x.shape[:-1], -1) + return torch.cat([x, fm.cos() + pm.cos(), fm.sin() + pm.sin()], dim=-1).type_as(x) + + +# --------------------------------------------------------------------------- +# Attention building blocks +# --------------------------------------------------------------------------- + +class MLP(nn.Module): + def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None): + super().__init__() + self.up_proj = ops.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device) + self.down_proj = ops.Linear(hidden_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.act_fn = nn.GELU(approximate="none") + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +class SelfAttention(nn.Module): + def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None): + super().__init__() + assert embed_dim % num_heads == 0 + self.num_heads = num_heads + head_dim = embed_dim // num_heads + self.c_qk = ops.Linear(embed_dim, 2 * embed_dim, bias=bias, dtype=dtype, device=device) + self.c_v = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device) + self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device) + + def forward(self, x, attn_mask=None, is_causal=False): + b, l, d = x.shape + q, k = self.c_qk(x).chunk(2, dim=-1) + v = self.c_v(x) + q = self.q_norm(q.view(b, l, self.num_heads, -1).transpose(1, 2)) + k = self.k_norm(k.view(b, l, self.num_heads, -1).transpose(1, 2)) + v = v.view(b, l, self.num_heads, -1).transpose(1, 2) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, + is_causal=is_causal and attn_mask is None) + y = y.transpose(1, 2).contiguous().view(b, l, d) + return self.c_proj(y) + + +class CrossAttention(nn.Module): + def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, dtype=None, device=None): + super().__init__() + assert embed_dim % num_heads == 0 + q_dim = q_dim or embed_dim + kv_dim = kv_dim or embed_dim + self.c_q = ops.Linear(q_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.c_k = ops.Linear(kv_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.c_v = ops.Linear(kv_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.num_heads = num_heads + + def forward(self, x, c, attn_mask=None): + q, k, v = self.c_q(x), self.c_k(c), self.c_v(c) + b, l, d = q.shape + s = k.shape[1] + q = q.view(b, l, self.num_heads, -1).transpose(1, 2) + k = k.view(b, s, self.num_heads, -1).transpose(1, 2) + v = v.view(b, s, self.num_heads, -1).transpose(1, 2) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) + y = y.transpose(1, 2).contiguous().view(b, l, d) + return self.c_proj(y) + + +class EncoderLayer(nn.Module): + def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None): + super().__init__() + self.ln_1 = CubeLayerNorm(embed_dim, eps=eps) + self.attn = SelfAttention(embed_dim, num_heads, bias=bias, eps=eps, dtype=dtype, device=device) + self.ln_2 = CubeLayerNorm(embed_dim, eps=eps) + self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device) + + def forward(self, x, attn_mask=None, is_causal=False): + x = x + self.attn(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal) + x = x + self.mlp(self.ln_2(x)) + return x + + +class EncoderCrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, eps=1e-6, dtype=None, device=None): + super().__init__() + q_dim = q_dim or embed_dim + kv_dim = kv_dim or embed_dim + self.attn = CrossAttention(embed_dim, num_heads, q_dim=q_dim, kv_dim=kv_dim, bias=bias, dtype=dtype, device=device) + self.ln_1 = CubeLayerNorm(q_dim, eps=eps) + self.ln_2 = CubeLayerNorm(kv_dim, eps=eps) + self.ln_f = CubeLayerNorm(embed_dim, eps=eps) + self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device) + + def forward(self, x, c, attn_mask=None): + x = x + self.attn(self.ln_1(x), self.ln_2(c), attn_mask=attn_mask) + x = x + self.mlp(self.ln_f(x)) + return x + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim, embed_dim, bias=True, dtype=None, device=None): + super().__init__() + self.in_layer = ops.Linear(in_dim, embed_dim, bias=bias, dtype=dtype, device=device) + self.silu = nn.SiLU() + self.out_layer = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device) + + def forward(self, x): + return self.out_layer(self.silu(self.in_layer(x))) + + +# --------------------------------------------------------------------------- +# Spherical VQ (decode-only parts) +# --------------------------------------------------------------------------- + +class SphericalVectorQuantizer(nn.Module): + def __init__(self, embed_dim, num_codes, width=None, dtype=None, device=None): + super().__init__() + self.num_codes = num_codes + self.codebook = ops.Embedding(num_codes, embed_dim, dtype=dtype, device=device) + width = width or embed_dim + if width != embed_dim: + self.c_in = ops.Linear(width, embed_dim, dtype=dtype, device=device) + self.c_x = ops.Linear(width, embed_dim, dtype=dtype, device=device) + self.c_out = ops.Linear(embed_dim, width, dtype=dtype, device=device) + else: + self.c_in = self.c_out = self.c_x = nn.Identity() + self.norm = CubeRMSNorm(embed_dim, elementwise_affine=False, dtype=dtype, device=device) + # "kl" codebook regularization (released config) + self.cb_weight = nn.Parameter(torch.ones([embed_dim], dtype=dtype, device=device)) + self.cb_bias = nn.Parameter(torch.zeros([embed_dim], dtype=dtype, device=device)) + + def cb_norm(self, x): + return x * self.cb_weight + self.cb_bias + + def get_codebook(self): + return self.norm(self.cb_norm(self.codebook.weight)) + + def lookup_codebook(self, q): + z_q = F.embedding(q, self.get_codebook()) + return self.c_out(z_q) + + +class OneDBottleNeck(nn.Module): + def __init__(self, block): + super().__init__() + self.block = block + + +# --------------------------------------------------------------------------- +# Decoders +# --------------------------------------------------------------------------- + +class OneDDecoder(nn.Module): + def __init__(self, num_latents, width, num_heads, num_layers, eps=1e-6, dtype=None, device=None): + super().__init__() + self.register_buffer("query", torch.empty([0, width]), persistent=False) + self.positional_encodings = nn.Parameter(torch.empty(num_latents, width, dtype=dtype, device=device)) + self.blocks = nn.ModuleList([ + EncoderLayer(width, num_heads, eps=eps, dtype=dtype, device=device) + for _ in range(num_layers) + ]) + + def forward(self, z): + h = z + self.positional_encodings[:z.shape[1]].unsqueeze(0).to(z.dtype) + for block in self.blocks: + h = block(h) + return h + + +class OneDOccupancyDecoder(nn.Module): + def __init__(self, embedder, out_features, width, num_heads, eps=1e-6, dtype=None, device=None): + super().__init__() + self.embedder = embedder + self.query_in = MLPEmbedder(embedder.out_dim, width, dtype=dtype, device=device) + self.attn_out = EncoderCrossAttentionLayer(width, num_heads, dtype=dtype, device=device) + self.ln_f = CubeLayerNorm(width, eps=eps, elementwise_affine=True, dtype=dtype, device=device) + self.c_head = ops.Linear(width, out_features, dtype=dtype, device=device) + + def forward(self, queries, latents): + x = self.query_in(self.embedder(queries)) + x = self.attn_out(x, latents) + return self.c_head(self.ln_f(x)) + + +# --------------------------------------------------------------------------- +# Top-level shape VAE +# --------------------------------------------------------------------------- + +def generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij"): + length = bbox_max - bbox_min + num_cells = np.exp2(resolution_base) + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + xs, ys, zs = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1).reshape(-1, 3) + grid_size = [int(num_cells) + 1] * 3 + return xyz, grid_size, length + + +class CubeShapeVAE(nn.Module): + """Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored).""" + + def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12, + num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6, + dtype=None, device=None): + super().__init__() + self.cfg_num_encoder_latents = num_encoder_latents + self.cfg_num_codes = num_codes + self.embedder = PhaseModulatedFourierEmbedder(num_freqs=num_freqs, input_dim=3, dtype=dtype, device=device) + self.bottleneck = OneDBottleNeck( + SphericalVectorQuantizer(embed_dim, num_codes, width, dtype=dtype, device=device) + ) + self.decoder = OneDDecoder(num_encoder_latents, width, num_heads, num_decoder_layers, + eps=eps, dtype=dtype, device=device) + self.occupancy_decoder = OneDOccupancyDecoder(self.embedder, out_dim, width, num_heads, + eps=eps, dtype=dtype, device=device) + + @torch.no_grad() + def decode_indices(self, shape_ids): + z_q = self.bottleneck.block.lookup_codebook(shape_ids) + return self.decoder(z_q) + + @torch.no_grad() + def query(self, queries, latents): + return self.occupancy_decoder(queries, latents).squeeze(-1) + + @torch.no_grad() + def extract_geometry(self, latents, bounds=(-1.05, -1.05, -1.05, 1.05, 1.05, 1.05), + resolution_base=8.0, chunk_size=100_000): + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz, grid_size, _ = generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij") + xyz = torch.from_numpy(xyz) + batch_size = latents.shape[0] + batch_logits = [] + for start in range(0, xyz.shape[0], chunk_size): + queries = xyz[start:start + chunk_size, :] + n = queries.shape[0] + if start > 0 and n < chunk_size: + queries = F.pad(queries, [0, 0, 0, chunk_size - n]) + bq = queries.unsqueeze(0).expand(batch_size, -1, -1).to(latents) + batch_logits.append(self.query(bq, latents)[:, :n]) + + grid_logits = torch.cat(batch_logits, dim=1).detach().view( + batch_size, grid_size[0], grid_size[1], grid_size[2]).float() + return grid_logits, grid_size, bbox_size, bbox_min + + +def grid_logits_to_mesh(grid_logit, grid_size, bbox_size, bbox_min, level=0.0): + """Marching cubes via skimage (matches upstream CPU fallback path).""" + from skimage import measure + vertices, faces, _, _ = measure.marching_cubes(grid_logit.cpu().numpy(), level, method="lewiner") + vertices = vertices / np.array(grid_size) * bbox_size + bbox_min + faces = faces[:, [2, 1, 0]] + return vertices.astype(np.float32), np.ascontiguousarray(faces) diff --git a/comfy/model_base.py b/comfy/model_base.py index ab4a11022..600bf204d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -44,6 +44,7 @@ import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.wan.model_animate import comfy.ldm.wan.ar_model +import comfy.ldm.cube.gpt import comfy.ldm.wan.model_wandancer import comfy.ldm.hunyuan3d.model import comfy.ldm.triposplat.model @@ -1903,6 +1904,26 @@ class Hunyuan3Dv2(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class Cube3D(BaseModel): + """Roblox Cube3D shape GPT (autoregressive). Generation goes through the + dedicated `cube` sampler (SamplerCustomAdvanced), never KSampler/apply_model.""" + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cube.gpt.DualStreamRoformer) + + def _apply_model(self, *args, **kwargs): + raise RuntimeError( + "Cube3D is an autoregressive token model. Use the 'cube' sampler " + "(SamplerCube + SamplerCustomAdvanced), not KSampler." + ) + + def extra_conds(self, **kwargs): + out = {} + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out + + class Hunyuan3Dv2_1(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7d0cab308..0f7750f29 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -654,6 +654,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if '{}shape_proj.weight'.format(key_prefix) in state_dict_keys and '{}lm_head.weight'.format(key_prefix) in state_dict_keys: # Roblox Cube3D shape GPT + dit_config = {} + dit_config["image_model"] = "cube3d" + n_embd = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[1] + dit_config["n_embd"] = n_embd + dit_config["shape_model_vocab_size"] = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[0] - 3 + dit_config["n_layer"] = count_blocks(state_dict_keys, '{}transformer.dual_blocks.'.format(key_prefix) + '{}.') + dit_config["n_single_layer"] = count_blocks(state_dict_keys, '{}transformer.single_blocks.'.format(key_prefix) + '{}.') + head_dim = state_dict['{}transformer.dual_blocks.0.attn.pre_x.q_norm.weight'.format(key_prefix)].shape[0] + dit_config["n_head"] = n_embd // head_dim + dit_config["shape_model_embed_dim"] = state_dict['{}shape_proj.weight'.format(key_prefix)].shape[1] + dit_config["text_model_embed_dim"] = state_dict['{}text_proj.weight'.format(key_prefix)].shape[1] + dit_config["use_bbox"] = '{}bbox_proj.weight'.format(key_prefix) in state_dict_keys + dit_config["bias"] = '{}text_proj.bias'.format(key_prefix) in state_dict_keys + dit_config["rope_theta"] = 10000 + return dit_config + if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape dit_config = {} diff --git a/comfy/sd.py b/comfy/sd.py index a66ba1bfb..54ab5570d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae +import comfy.ldm.cube.vae import comfy.ldm.triposplat.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae @@ -489,6 +490,7 @@ class VAE: self.disable_offload = False self.not_video = False self.size = None + self.cube3d = False self.downscale_index_formula = None self.upscale_index_formula = None @@ -777,6 +779,25 @@ class VAE: self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE() self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # Roblox Cube3D shape tokenizer (OneDAutoEncoder, decode-only) + elif "bottleneck.block.codebook.weight" in sd: + self.cube3d = True + self.latent_dim = 1 + embed_dim = sd["bottleneck.block.codebook.weight"].shape[1] + num_codes = sd["bottleneck.block.codebook.weight"].shape[0] + width = sd["bottleneck.block.c_out.weight"].shape[0] + num_encoder_latents = sd["decoder.positional_encodings"].shape[0] + head_dim = sd["decoder.blocks.0.attn.q_norm.weight"].shape[0] + num_heads = width // head_dim + num_freqs = sd["embedder.weight"].shape[1] + num_decoder_layers = len({k.split(".")[2] for k in sd if k.startswith("decoder.blocks.")}) + self.first_stage_model = comfy.ldm.cube.vae.CubeShapeVAE( + num_encoder_latents=num_encoder_latents, embed_dim=embed_dim, width=width, + num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers, + num_codes=num_codes, + ) + self.memory_used_decode = lambda shape, dtype: (1000 * shape[1] * 768) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.float32] elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 3be935577..388e88a1a 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1550,6 +1550,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini + +class Cube3D(supported_models_base.BASE): + unet_config = { + "image_model": "cube3d", + } + + unet_extra_config = {} + + sampling_settings = {} + + latent_format = latent_formats.LatentFormat + + memory_usage_factor = 1.0 + + # Upstream keeps fp32 weights and uses bf16 autocast during the forward pass + # (see sample_cube). Prefer fp32 weights for parity; bf16 is the low-VRAM fallback. + supported_inference_dtypes = [torch.float32, torch.bfloat16] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Cube3D(self, device=device) + + def clip_target(self, state_dict={}): + return None + class TripoSplat(supported_models_base.BASE): # Image -> 3D gaussian splat flow denoiser unet_config = { @@ -2292,6 +2316,7 @@ models = [ Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, + Cube3D, TripoSplat, HiDream, HiDreamO1, diff --git a/comfy_extras/nodes_cube.py b/comfy_extras/nodes_cube.py new file mode 100644 index 000000000..625f09e13 --- /dev/null +++ b/comfy_extras/nodes_cube.py @@ -0,0 +1,153 @@ +""" +Nodes for native Roblox Cube3D text-to-3D support. + +Graph: + CLIPLoader(clip-l) -> CLIPTextEncode -> CONDITIONING + UNETLoader(shape_gpt) -> MODEL --\ + VAELoader(shape_tokenizer) -> VAE -> CubeCodebookPatch -> MODEL + CFGGuider(MODEL, pos, neg, cfg) + SamplerCube + (trivial sigmas) + EmptyCubeLatent + -> SamplerCustomAdvanced -> LATENT (token IDs) + VAEDecodeCube(VAE, LATENT) -> MESH -> SaveGLB +""" + +import numpy as np +import torch +from typing_extensions import override + +import comfy.ldm.cube.vae +import comfy.model_management +import comfy.samplers +from comfy_api.latest import ComfyExtension, IO, Types +from comfy_extras.nodes_save_3d import pack_variable_mesh_batch + + +class EmptyCubeLatent(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyCubeLatent", + category="latent/3d", + inputs=[ + IO.Int.Input("num_tokens", default=1024, min=1, max=8192, + tooltip="Shape token sequence length. Must match the tokenizer " + "(1024 for cube3d-v0.5, 512 for v0.1)."), + IO.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[IO.Latent.Output()], + ) + + @classmethod + def execute(cls, num_tokens, batch_size) -> IO.NodeOutput: + latent = torch.zeros([batch_size, num_tokens], device=comfy.model_management.intermediate_device()) + return IO.NodeOutput({"samples": latent, "type": "cube_tokens"}) + + +class CubeCodebookPatch(IO.ComfyNode): + """Inject the projected VQ codebook into the GPT token-embedding table. + + Upstream copies shape_proj(tokenizer.codebook) into wte.weight[:num_codes] at load + time; without it generation is garbage. Done here as a ModelPatcher object patch so + it composes with normal model loading/offload.""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="CubeCodebookPatch", + display_name="Cube Codebook Patch", + category="advanced/model", + inputs=[ + IO.Model.Input("model"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Model.Output()], + ) + + @classmethod + def execute(cls, model, vae) -> IO.NodeOutput: + gpt = model.get_model_object("diffusion_model") + codebook = vae.first_stage_model.bottleneck.block.get_codebook() # (num_codes, embed_dim) fp32 + w = gpt.shape_proj.weight + proj = gpt.shape_proj(codebook.to(device=w.device, dtype=w.dtype)) # (num_codes, n_embd) + + old = model.get_model_object("diffusion_model.transformer.wte.weight") + new = old.clone() + new[:proj.shape[0]] = proj.to(device=new.device, dtype=new.dtype) + + m = model.clone() + m.add_object_patch("diffusion_model.transformer.wte.weight", + torch.nn.Parameter(new, requires_grad=False)) + return IO.NodeOutput(m) + + +class SamplerCube(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SamplerCube", + display_name="Sampler Cube (autoregressive)", + category="sampling/custom_sampling/samplers", + inputs=[ + IO.Float.Input("top_p", default=1.0, min=0.0, max=1.0, step=0.01, + tooltip="1.0 = deterministic greedy (upstream default). " + "<1.0 enables nucleus sampling."), + ], + outputs=[IO.Sampler.Output()], + ) + + @classmethod + def execute(cls, top_p) -> IO.NodeOutput: + return IO.NodeOutput(comfy.samplers.ksampler("cube", {"top_p": top_p})) + + +class VAEDecodeCube(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeCube", + display_name="VAE Decode Cube (3D)", + category="latent/3d", + inputs=[ + IO.Vae.Input("vae"), + IO.Latent.Input("samples"), + IO.Float.Input("resolution_base", default=8.0, min=4.0, max=10.0, step=0.5, + tooltip="Grid cells per axis = 2^resolution_base. 8.0 matches " + "upstream default (257^3 grid)."), + IO.Int.Input("chunk_size", default=100000, min=1000, max=2000000, advanced=True), + ], + outputs=[IO.Mesh.Output()], + ) + + @classmethod + def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput: + comfy.model_management.load_models_gpu([vae.patcher]) + tok = vae.first_stage_model + ids = samples["samples"][:, :tok.cfg_num_encoder_latents].long() + ids = ids.clamp(0, tok.cfg_num_codes - 1).to(vae.device) + + latents = tok.decode_indices(ids) + grid, grid_size, bbox_size, bbox_min = tok.extract_geometry( + latents, resolution_base=resolution_base, chunk_size=chunk_size) + + verts_list, faces_list = [], [] + for i in range(grid.shape[0]): + v, f = comfy.ldm.cube.vae.grid_logits_to_mesh(grid[i], grid_size, bbox_size, bbox_min) + verts_list.append(torch.from_numpy(v)) + faces_list.append(torch.from_numpy(f.astype(np.int64))) + + mesh = pack_variable_mesh_batch(verts_list, faces_list) + return IO.NodeOutput(mesh) + + +class CubeExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyCubeLatent, + CubeCodebookPatch, + SamplerCube, + VAEDecodeCube, + ] + + +async def comfy_entrypoint() -> CubeExtension: + return CubeExtension() diff --git a/nodes.py b/nodes.py index 0d422d418..2069ec62c 100644 --- a/nodes.py +++ b/nodes.py @@ -2433,6 +2433,7 @@ async def init_builtin_extra_nodes(): "nodes_kandinsky5.py", "nodes_wanmove.py", "nodes_ar_video.py", + "nodes_cube.py", "nodes_image_compare.py", "nodes_zimage.py", "nodes_glsl.py",