mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
.
This commit is contained in:
parent
0e51bee64f
commit
92aa058a58
@ -228,6 +228,8 @@ class DINOv3ViTLayer(nn.Module):
|
|||||||
class DINOv3ViTModel(nn.Module):
|
class DINOv3ViTModel(nn.Module):
|
||||||
def __init__(self, config, dtype, device, operations):
|
def __init__(self, config, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if dtype == torch.float16:
|
||||||
|
dtype = torch.bfloat16
|
||||||
num_hidden_layers = config["num_hidden_layers"]
|
num_hidden_layers = config["num_hidden_layers"]
|
||||||
hidden_size = config["hidden_size"]
|
hidden_size = config["hidden_size"]
|
||||||
num_attention_heads = config["num_attention_heads"]
|
num_attention_heads = config["num_attention_heads"]
|
||||||
|
|||||||
@ -678,10 +678,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
if self.use_checkpoint:
|
return self._forward(x, mod, context, phases)
|
||||||
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False)
|
|
||||||
else:
|
|
||||||
return self._forward(x, mod, context, phases)
|
|
||||||
|
|
||||||
|
|
||||||
class SparseStructureFlowModel(nn.Module):
|
class SparseStructureFlowModel(nn.Module):
|
||||||
@ -823,18 +820,25 @@ class Trellis2(nn.Module):
|
|||||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||||
args.pop("out_channels")
|
args.pop("out_channels")
|
||||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||||
|
self.guidance_interval = [0.6, 1.0]
|
||||||
|
self.guidance_interval_txt = [0.6, 0.9]
|
||||||
|
|
||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
mode = x.generation_mode
|
mode = kwargs.get("generation_mode")
|
||||||
|
sigmas = kwargs.get("sigmas")[0].item()
|
||||||
|
cond = context.chunk(2)
|
||||||
|
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||||
|
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||||
|
|
||||||
if mode == "shape_generation":
|
if mode == "shape_generation":
|
||||||
# TODO
|
# TODO
|
||||||
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
||||||
elif mode == "texture_generation":
|
elif mode == "texture_generation":
|
||||||
out = self.shape2txt(x, timestep, context)
|
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||||
else: # structure
|
else: # structure
|
||||||
timestep = timestep_reshift(timestep)
|
timestep = timestep_reshift(timestep)
|
||||||
out = self.structure_model(x, timestep, context)
|
out = self.structure_model(x, timestep, context if not shape_rule else cond)
|
||||||
|
|
||||||
out.generation_mode = mode
|
out.generation_mode = mode
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -250,8 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||||
latent.generation_mode = "shape_generation"
|
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"})
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
|
||||||
|
|
||||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -272,8 +271,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
# TODO
|
# TODO
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
|
latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||||
latent.generation_mode = "texture_generation"
|
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation"})
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -293,8 +291,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
in_channels = 8
|
in_channels = 8
|
||||||
resolution = 16
|
resolution = 16
|
||||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||||
latent.generation_mode = "structure_generation"
|
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"})
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
|
||||||
|
|
||||||
def simplify_fn(vertices, faces, target=100000):
|
def simplify_fn(vertices, faces, target=100000):
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user