mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
.
This commit is contained in:
parent
0e51bee64f
commit
92aa058a58
@ -228,6 +228,8 @@ class DINOv3ViTLayer(nn.Module):
|
||||
class DINOv3ViTModel(nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
if dtype == torch.float16:
|
||||
dtype = torch.bfloat16
|
||||
num_hidden_layers = config["num_hidden_layers"]
|
||||
hidden_size = config["hidden_size"]
|
||||
num_attention_heads = config["num_attention_heads"]
|
||||
|
||||
@ -678,10 +678,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.use_checkpoint:
|
||||
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False)
|
||||
else:
|
||||
return self._forward(x, mod, context, phases)
|
||||
return self._forward(x, mod, context, phases)
|
||||
|
||||
|
||||
class SparseStructureFlowModel(nn.Module):
|
||||
@ -823,18 +820,25 @@ class Trellis2(nn.Module):
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
args.pop("out_channels")
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||
self.guidance_interval = [0.6, 1.0]
|
||||
self.guidance_interval_txt = [0.6, 0.9]
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
embeds = kwargs.get("embeds")
|
||||
mode = x.generation_mode
|
||||
mode = kwargs.get("generation_mode")
|
||||
sigmas = kwargs.get("sigmas")[0].item()
|
||||
cond = context.chunk(2)
|
||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||
|
||||
if mode == "shape_generation":
|
||||
# TODO
|
||||
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
||||
elif mode == "texture_generation":
|
||||
out = self.shape2txt(x, timestep, context)
|
||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||
else: # structure
|
||||
timestep = timestep_reshift(timestep)
|
||||
out = self.structure_model(x, timestep, context)
|
||||
out = self.structure_model(x, timestep, context if not shape_rule else cond)
|
||||
|
||||
out.generation_mode = mode
|
||||
return out
|
||||
|
||||
@ -250,8 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
|
||||
in_channels = 32
|
||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||
latent.generation_mode = "shape_generation"
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"})
|
||||
|
||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -272,8 +271,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
# TODO
|
||||
in_channels = 32
|
||||
latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||
latent.generation_mode = "texture_generation"
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation"})
|
||||
|
||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -293,8 +291,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
in_channels = 8
|
||||
resolution = 16
|
||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||
latent.generation_mode = "structure_generation"
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"})
|
||||
|
||||
def simplify_fn(vertices, faces, target=100000):
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user