mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
upscale node + simple node simplification
This commit is contained in:
parent
011f624dd5
commit
2d904b28da
@ -53,7 +53,6 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Voxel.Input("structure_output"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("resolution", options=["512", "1024"], default="512")
|
||||
],
|
||||
@ -64,7 +63,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, structure_output, vae, resolution):
|
||||
def execute(cls, samples, vae, resolution):
|
||||
|
||||
resolution = int(resolution)
|
||||
patcher = vae.patcher
|
||||
@ -72,8 +71,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
comfy.model_management.load_model_gpu(patcher)
|
||||
|
||||
vae = vae.first_stage_model
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
coords = samples["coords"]
|
||||
|
||||
samples = samples["samples"]
|
||||
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||
@ -93,7 +91,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Voxel.Input("structure_output"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.AnyType.Input("shape_subs"),
|
||||
],
|
||||
@ -103,15 +100,15 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, structure_output, vae, shape_subs):
|
||||
def execute(cls, samples, vae, shape_subs):
|
||||
|
||||
patcher = vae.patcher
|
||||
device = comfy.model_management.get_torch_device()
|
||||
comfy.model_management.load_model_gpu(patcher)
|
||||
|
||||
vae = vae.first_stage_model
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
coords = samples["coords"]
|
||||
|
||||
samples = samples["samples"]
|
||||
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||
std = tex_slat_normalization["std"].to(samples)
|
||||
@ -161,6 +158,56 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class Trellis2UpsampleCascade(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Trellis2UpsampleCascade",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("shape_latent_512"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024"),
|
||||
IO.Int.Input("max_tokens", default=49152, min=1024, max=100000)
|
||||
],
|
||||
outputs=[
|
||||
IO.AnyType.Output("hr_coords"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, shape_latent_512, vae, target_resolution, max_tokens):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||
coords_512 = shape_latent_512["coords"].to(device)
|
||||
|
||||
slat = shape_norm(feats, coords_512)
|
||||
|
||||
decoder = vae.first_stage_model.shape_dec
|
||||
decoder.to(device)
|
||||
|
||||
slat.feats = slat.feats.to(next(decoder.parameters()).dtype)
|
||||
hr_coords = decoder.upsample(slat, upsample_times=4)
|
||||
decoder.cpu()
|
||||
|
||||
lr_resolution = 512
|
||||
hr_resolution = int(target_resolution)
|
||||
|
||||
while True:
|
||||
quant_coords = torch.cat([
|
||||
hr_coords[:, :1],
|
||||
((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
|
||||
], dim=1)
|
||||
final_coords = quant_coords.unique(dim=0)
|
||||
num_tokens = final_coords.shape[0]
|
||||
|
||||
if num_tokens < max_tokens or hr_resolution <= 1024:
|
||||
break
|
||||
hr_resolution -= 128
|
||||
|
||||
return IO.NodeOutput(final_coords.cpu())
|
||||
|
||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
||||
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
||||
|
||||
@ -282,7 +329,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
node_id="EmptyShapeLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("structure_output"),
|
||||
IO.AnyType.Input("structure_or_coords"),
|
||||
IO.Model.Input("model")
|
||||
],
|
||||
outputs=[
|
||||
@ -292,9 +339,13 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, structure_output, model):
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
def execute(cls, structure_or_coords, model):
|
||||
# to accept the upscaled coords
|
||||
if hasattr(structure_or_coords, "data"):
|
||||
decoded = structure_or_coords.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
else:
|
||||
coords = structure_or_coords
|
||||
in_channels = 32
|
||||
# image like format
|
||||
latent = torch.randn(1, in_channels, coords.shape[0], 1)
|
||||
@ -307,7 +358,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
|
||||
model.model_options["transformer_options"]["coords"] = coords
|
||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
|
||||
return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
|
||||
|
||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -556,6 +607,7 @@ class Trellis2Extension(ComfyExtension):
|
||||
VaeDecodeTextureTrellis,
|
||||
VaeDecodeShapeTrellis,
|
||||
VaeDecodeStructureTrellis2,
|
||||
Trellis2UpsampleCascade,
|
||||
PostProcessMesh
|
||||
]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user