mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-18 22:42:35 +08:00
resolution logit
This commit is contained in:
parent
39270fdca9
commit
7d444a4fcc
@ -812,7 +812,7 @@ class Trellis2(nn.Module):
|
|||||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||||
else: # structure
|
else: # structure
|
||||||
timestep = timestep_reshift(timestep)
|
#timestep = timestep_reshift(timestep)
|
||||||
orig_bsz = x.shape[0]
|
orig_bsz = x.shape[0]
|
||||||
if shape_rule:
|
if shape_rule:
|
||||||
x = x[0].unsqueeze(0)
|
x = x[0].unsqueeze(0)
|
||||||
|
|||||||
@ -136,7 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
IO.Latent.Input("samples"),
|
IO.Latent.Input("samples"),
|
||||||
IO.Voxel.Input("structure_output"),
|
IO.Voxel.Input("structure_output"),
|
||||||
IO.Vae.Input("vae"),
|
IO.Vae.Input("vae"),
|
||||||
IO.Int.Input("resolution", tooltip="Shape Generation Resolution"),
|
IO.Combo.Input("resolution", options=["512", "1024"], default="512")
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Mesh.Output("mesh"),
|
IO.Mesh.Output("mesh"),
|
||||||
@ -147,6 +147,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, structure_output, vae, resolution):
|
def execute(cls, samples, structure_output, vae, resolution):
|
||||||
|
|
||||||
|
resolution = int(resolution)
|
||||||
patcher = vae.patcher
|
patcher = vae.patcher
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
comfy.model_management.load_model_gpu(patcher)
|
comfy.model_management.load_model_gpu(patcher)
|
||||||
@ -154,14 +155,18 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
vae = vae.first_stage_model
|
vae = vae.first_stage_model
|
||||||
decoded = structure_output.data.unsqueeze(1)
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
|
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
samples = samples.squeeze(-1).transpose(1, 2).to(device)
|
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||||
std = shape_slat_normalization["std"].to(samples)
|
std = shape_slat_normalization["std"].to(samples)
|
||||||
mean = shape_slat_normalization["mean"].to(samples)
|
mean = shape_slat_normalization["mean"].to(samples)
|
||||||
samples = SparseTensor(feats = samples, coords=coords)
|
samples = SparseTensor(feats = samples, coords=coords)
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
||||||
|
faces = torch.stack([m.faces for m in mesh])
|
||||||
|
verts = torch.stack([m.vertices for m in mesh])
|
||||||
|
mesh = Types.MESH(vertices=verts, faces=faces)
|
||||||
return IO.NodeOutput(mesh, subs)
|
return IO.NodeOutput(mesh, subs)
|
||||||
|
|
||||||
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||||
@ -192,13 +197,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
decoded = structure_output.data.unsqueeze(1)
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
samples = samples.squeeze(-1).transpose(1, 2).to(device)
|
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||||
std = tex_slat_normalization["std"].to(samples)
|
std = tex_slat_normalization["std"].to(samples)
|
||||||
mean = tex_slat_normalization["mean"].to(samples)
|
mean = tex_slat_normalization["mean"].to(samples)
|
||||||
samples = SparseTensor(feats = samples, coords=coords)
|
samples = SparseTensor(feats = samples, coords=coords)
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
mesh = vae.decode_tex_slat(samples, shape_subs)
|
mesh = vae.decode_tex_slat(samples, shape_subs)
|
||||||
|
faces = torch.stack([m.faces for m in mesh])
|
||||||
|
verts = torch.stack([m.vertices for m in mesh])
|
||||||
|
mesh = Types.MESH(vertices=verts, faces=faces)
|
||||||
return IO.NodeOutput(mesh)
|
return IO.NodeOutput(mesh)
|
||||||
|
|
||||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||||
@ -210,6 +218,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.Latent.Input("samples"),
|
IO.Latent.Input("samples"),
|
||||||
IO.Vae.Input("vae"),
|
IO.Vae.Input("vae"),
|
||||||
|
IO.Combo.Input("resolution", options=["32", "64"], default="32")
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Voxel.Output("structure_output"),
|
IO.Voxel.Output("structure_output"),
|
||||||
@ -217,7 +226,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae):
|
def execute(cls, samples, vae, resolution):
|
||||||
|
resolution = int(resolution)
|
||||||
vae = vae.first_stage_model
|
vae = vae.first_stage_model
|
||||||
decoder = vae.struct_dec
|
decoder = vae.struct_dec
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
@ -227,6 +237,11 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
samples = samples.to(load_device)
|
samples = samples.to(load_device)
|
||||||
decoded = decoder(samples)>0
|
decoded = decoder(samples)>0
|
||||||
decoder.to(offload_device)
|
decoder.to(offload_device)
|
||||||
|
current_res = decoded.shape[2]
|
||||||
|
|
||||||
|
if current_res != resolution:
|
||||||
|
ratio = current_res // resolution
|
||||||
|
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
|
||||||
out = Types.VOXEL(decoded.squeeze(1).float())
|
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||||
return IO.NodeOutput(out)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user