coderabbit 2

This commit is contained in:
Yousef Rafat 2026-02-20 21:13:13 +02:00
parent f3d4125e49
commit b3da8ed4c5
3 changed files with 37 additions and 16 deletions

View File

@ -207,11 +207,14 @@ class TorchHashMap:
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
flat = flat_keys.long() flat = flat_keys.long()
if self._n == 0:
return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
idx = torch.searchsorted(self.sorted_keys, flat) idx = torch.searchsorted(self.sorted_keys, flat)
found = (idx < self._n) & (self.sorted_keys[idx] == flat) idx_safe = torch.clamp(idx, max=self._n - 1)
found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
if found.any(): if found.any():
out[found] = self.sorted_vals[idx[found]] out[found] = self.sorted_vals[idx_safe[found]]
return out return out

View File

@ -8,7 +8,6 @@ from comfy.ldm.trellis2.attention import (
) )
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
from comfy.ldm.flux.math import apply_rope, apply_rope1 from comfy.ldm.flux.math import apply_rope, apply_rope1
import builtins
class SparseGELU(nn.GELU): class SparseGELU(nn.GELU):
def forward(self, input: VarLenTensor) -> VarLenTensor: def forward(self, input: VarLenTensor) -> VarLenTensor:
@ -829,19 +828,19 @@ class Trellis2(nn.Module):
def forward(self, x, timestep, context, **kwargs): def forward(self, x, timestep, context, **kwargs):
# FIXME: should find a way to distinguish between 512/1024 models # FIXME: should find a way to distinguish between 512/1024 models
# currently assumes 1024 # currently assumes 1024
transformer_options = kwargs.get("transformer_options")
embeds = kwargs.get("embeds") embeds = kwargs.get("embeds")
_, cond = context.chunk(2) #_, cond = context.chunk(2) # TODO
cond = embeds.chunk(2)[0] cond = embeds.chunk(2)[0]
context = torch.cat([torch.zeros_like(cond), cond]) context = torch.cat([torch.zeros_like(cond), cond])
mode = getattr(builtins, "TRELLIS_MODE", "structure_generation") coords = transformer_options.get("coords", None)
coords = getattr(builtins, "TRELLIS_COORDS", None) mode = transformer_options.get("generation_mode", "structure_generation")
if coords is not None: if coords is not None:
x = x.squeeze(0) x = x.squeeze(0)
not_struct_mode = True not_struct_mode = True
else: else:
mode = "structure_generation" mode = "structure_generation"
not_struct_mode = False not_struct_mode = False
transformer_options = kwargs.get("transformer_options")
sigmas = transformer_options.get("sigmas")[0].item() sigmas = transformer_options.get("sigmas")[0].item()
if sigmas < 1.00001: if sigmas < 1.00001:
timestep *= 1000.0 timestep *= 1000.0

View File

@ -4,7 +4,6 @@ import torch
import comfy.model_management import comfy.model_management
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import builtins
shape_slat_normalization = { shape_slat_normalization = {
"mean": torch.tensor([ "mean": torch.tensor([
@ -258,21 +257,31 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
category="latent/3d", category="latent/3d",
inputs=[ inputs=[
IO.Voxel.Input("structure_output"), IO.Voxel.Input("structure_output"),
IO.Model.Input("model")
], ],
outputs=[ outputs=[
IO.Latent.Output(), IO.Latent.Output(),
IO.Model.Output()
] ]
) )
@classmethod @classmethod
def execute(cls, structure_output): def execute(cls, structure_output, model):
decoded = structure_output.data.unsqueeze(1) decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
in_channels = 32 in_channels = 32
latent = torch.randn(1, coords.shape[0], in_channels) latent = torch.randn(1, coords.shape[0], in_channels)
builtins.TRELLIS_MODE = "shape_generation" model = model.clone()
builtins.TRELLIS_COORDS = coords if "transformer_options" not in model.model_options:
return IO.NodeOutput({"samples": latent, "type": "trellis2"}) model.model_options = {}
else:
model.model_options = model.model_options.copy()
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
class EmptyTextureLatentTrellis2(IO.ComfyNode): class EmptyTextureLatentTrellis2(IO.ComfyNode):
@classmethod @classmethod
@ -285,19 +294,29 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
], ],
outputs=[ outputs=[
IO.Latent.Output(), IO.Latent.Output(),
IO.Model.Output()
] ]
) )
@classmethod @classmethod
def execute(cls, structure_output): def execute(cls, structure_output, model):
# TODO # TODO
decoded = structure_output.data.unsqueeze(1) decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
in_channels = 32 in_channels = 32
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
builtins.TRELLIS_MODE = "texture_generation" model = model.clone()
builtins.TRELLIS_COORDS = coords if "transformer_options" not in model.model_options:
return IO.NodeOutput({"samples": latent, "type": "trellis2"}) model.model_options = {}
else:
model.model_options = model.model_options.copy()
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
class EmptyStructureLatentTrellis2(IO.ComfyNode): class EmptyStructureLatentTrellis2(IO.ComfyNode):
@classmethod @classmethod