diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 9cb231e28..ce6b2edd9 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -228,6 +228,8 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): def __init__(self, config, dtype, device, operations): super().__init__() + if dtype == torch.float16: + dtype = torch.bfloat16 num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] num_attention_heads = config["num_attention_heads"] diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 760372f5c..76fe8ad19 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -678,10 +678,7 @@ class ModulatedTransformerCrossBlock(nn.Module): return x def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) - else: - return self._forward(x, mod, context, phases) + return self._forward(x, mod, context, phases) class SparseStructureFlowModel(nn.Module): @@ -823,18 +820,25 @@ class Trellis2(nn.Module): self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) + self.guidance_interval = [0.6, 1.0] + self.guidance_interval_txt = [0.6, 0.9] def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") - mode = x.generation_mode + mode = kwargs.get("generation_mode") + sigmas = kwargs.get("sigmas")[0].item() + cond = context.chunk(2) + shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] + txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] + if mode == "shape_generation": # TODO out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) elif mode == "texture_generation": - out = self.shape2txt(x, timestep, context) + out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) - out = self.structure_model(x, timestep, context) + out = self.structure_model(x, timestep, context if not shape_rule else cond) out.generation_mode = mode return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index a5c387c1d..560751091 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -250,8 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) - latent.generation_mode = "shape_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -272,8 +271,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) - latent.generation_mode = "texture_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod @@ -293,8 +291,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - latent.generation_mode = "structure_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"}) def simplify_fn(vertices, faces, target=100000):