mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
resolution logic
This commit is contained in:
parent
39270fdca9
commit
7d444a4fcc
@ -812,7 +812,7 @@ class Trellis2(nn.Module):
|
||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||
else: # structure
|
||||
timestep = timestep_reshift(timestep)
|
||||
#timestep = timestep_reshift(timestep)
|
||||
orig_bsz = x.shape[0]
|
||||
if shape_rule:
|
||||
x = x[0].unsqueeze(0)
|
||||
|
||||
@ -136,7 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Voxel.Input("structure_output"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Int.Input("resolution", tooltip="Shape Generation Resolution"),
|
||||
IO.Combo.Input("resolution", options=["512", "1024"], default="512")
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("mesh"),
|
||||
@ -147,6 +147,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, samples, structure_output, vae, resolution):
|
||||
|
||||
resolution = int(resolution)
|
||||
patcher = vae.patcher
|
||||
device = comfy.model_management.get_torch_device()
|
||||
comfy.model_management.load_model_gpu(patcher)
|
||||
@ -154,14 +155,18 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
vae = vae.first_stage_model
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
|
||||
samples = samples["samples"]
|
||||
samples = samples.squeeze(-1).transpose(1, 2).to(device)
|
||||
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||
std = shape_slat_normalization["std"].to(samples)
|
||||
mean = shape_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords)
|
||||
samples = samples * std + mean
|
||||
|
||||
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
||||
faces = torch.stack([m.faces for m in mesh])
|
||||
verts = torch.stack([m.vertices for m in mesh])
|
||||
mesh = Types.MESH(vertices=verts, faces=faces)
|
||||
return IO.NodeOutput(mesh, subs)
|
||||
|
||||
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
@ -192,13 +197,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
samples = samples["samples"]
|
||||
samples = samples.squeeze(-1).transpose(1, 2).to(device)
|
||||
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||
std = tex_slat_normalization["std"].to(samples)
|
||||
mean = tex_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords)
|
||||
samples = samples * std + mean
|
||||
|
||||
mesh = vae.decode_tex_slat(samples, shape_subs)
|
||||
faces = torch.stack([m.faces for m in mesh])
|
||||
verts = torch.stack([m.vertices for m in mesh])
|
||||
mesh = Types.MESH(vertices=verts, faces=faces)
|
||||
return IO.NodeOutput(mesh)
|
||||
|
||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
@ -210,6 +218,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("resolution", options=["32", "64"], default="32")
|
||||
],
|
||||
outputs=[
|
||||
IO.Voxel.Output("structure_output"),
|
||||
@ -217,7 +226,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae):
|
||||
def execute(cls, samples, vae, resolution):
|
||||
resolution = int(resolution)
|
||||
vae = vae.first_stage_model
|
||||
decoder = vae.struct_dec
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
@ -227,6 +237,11 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
samples = samples.to(load_device)
|
||||
decoded = decoder(samples)>0
|
||||
decoder.to(offload_device)
|
||||
current_res = decoded.shape[2]
|
||||
|
||||
if current_res != resolution:
|
||||
ratio = current_res // resolution
|
||||
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
|
||||
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user