This commit is contained in:
Yousef Rafat 2026-02-13 21:05:59 +02:00
parent 0e51bee64f
commit 92aa058a58
3 changed files with 16 additions and 13 deletions

View File

@@ -228,6 +228,8 @@ class DINOv3ViTLayer(nn.Module):
class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
if dtype == torch.float16:
dtype = torch.bfloat16
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
num_attention_heads = config["num_attention_heads"]

View File

@@ -678,10 +678,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
return x
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False)
else:
return self._forward(x, mod, context, phases)
return self._forward(x, mod, context, phases)
class SparseStructureFlowModel(nn.Module):
@@ -823,18 +820,25 @@ class Trellis2(nn.Module):
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
args.pop("out_channels")
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
self.guidance_interval = [0.6, 1.0]
self.guidance_interval_txt = [0.6, 0.9]
def forward(self, x, timestep, context, **kwargs):
embeds = kwargs.get("embeds")
mode = x.generation_mode
mode = kwargs.get("generation_mode")
sigmas = kwargs.get("sigmas")[0].item()
cond = context.chunk(2)
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
if mode == "shape_generation":
# TODO
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
elif mode == "texture_generation":
out = self.shape2txt(x, timestep, context)
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
else: # structure
timestep = timestep_reshift(timestep)
out = self.structure_model(x, timestep, context)
out = self.structure_model(x, timestep, context if not shape_rule else cond)
out.generation_mode = mode
return out

View File

@@ -250,8 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
in_channels = 32
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
latent.generation_mode = "shape_generation"
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"})
class EmptyTextureLatentTrellis2(IO.ComfyNode):
@classmethod
@@ -272,8 +271,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
# TODO
in_channels = 32
latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
latent.generation_mode = "texture_generation"
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation"})
class EmptyStructureLatentTrellis2(IO.ComfyNode):
@classmethod
@@ -293,8 +291,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
in_channels = 8
resolution = 16
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
latent.generation_mode = "structure_generation"
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"})
def simplify_fn(vertices, faces, target=100000):