mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
work on txt gen
This commit is contained in:
parent
def8947e75
commit
56e52e5d03
@ -754,6 +754,7 @@ class Trellis2(nn.Module):
|
||||
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
||||
}
|
||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||
self.shape2txt = None
|
||||
if init_txt_model:
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
|
||||
@ -835,11 +836,12 @@ class Trellis2(nn.Module):
|
||||
slat = transformer_options.get("shape_slat")
|
||||
if slat is None:
|
||||
raise ValueError("shape_slat can't be None")
|
||||
slat.feats = slat.feats.repeat(B, 1)
|
||||
x_st = sparse_cat([x_st, slat])
|
||||
|
||||
base_slat_feats = slat.feats[:N]
|
||||
slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device)
|
||||
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1))
|
||||
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||
else: # structure
|
||||
#timestep = timestep_reshift(timestep)
|
||||
orig_bsz = x.shape[0]
|
||||
if shape_rule:
|
||||
x = x[0].unsqueeze(0)
|
||||
@ -850,8 +852,7 @@ class Trellis2(nn.Module):
|
||||
|
||||
if not_struct_mode:
|
||||
out = out.feats
|
||||
if not_struct_mode:
|
||||
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
||||
if rule and orig_bsz > 1:
|
||||
out = out.repeat(orig_bsz, 1, 1, 1)
|
||||
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
||||
if rule and orig_bsz > 1:
|
||||
out = out.repeat(orig_bsz, 1, 1, 1)
|
||||
return out
|
||||
|
||||
@ -95,7 +95,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
IO.AnyType.Input("shape_subs"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("mesh"),
|
||||
IO.Voxel.Output("voxel"),
|
||||
]
|
||||
)
|
||||
|
||||
@ -116,11 +116,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
samples = SparseTensor(feats = samples, coords=coords)
|
||||
samples = samples * std + mean
|
||||
|
||||
mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5
|
||||
faces = torch.stack([m.faces for m in mesh])
|
||||
verts = torch.stack([m.vertices for m in mesh])
|
||||
mesh = Types.MESH(vertices=verts, faces=faces)
|
||||
return IO.NodeOutput(mesh)
|
||||
voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5
|
||||
voxel = Types.VOXEL(voxel)
|
||||
return IO.NodeOutput(voxel)
|
||||
|
||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -377,7 +375,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
node_id="EmptyTextureLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("structure_output"),
|
||||
IO.Voxel.Input("structure_or_coords"),
|
||||
IO.Latent.Input("shape_latent"),
|
||||
IO.Model.Input("model")
|
||||
],
|
||||
@ -388,16 +386,21 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, structure_output, shape_latent, model):
|
||||
# TODO
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
in_channels = 32
|
||||
def execute(cls, structure_or_coords, shape_latent, model):
|
||||
channels = 32
|
||||
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
|
||||
decoded = structure_or_coords.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
|
||||
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
|
||||
coords = structure_or_coords.int()
|
||||
|
||||
shape_latent = shape_latent["samples"]
|
||||
if shape_latent.ndim == 4:
|
||||
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||
shape_latent = shape_norm(shape_latent, coords)
|
||||
|
||||
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
||||
latent = torch.randn(1, channels, coords.shape[0], 1)
|
||||
model = model.clone()
|
||||
model.model_options = model.model_options.copy()
|
||||
if "transformer_options" in model.model_options:
|
||||
@ -406,9 +409,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
model.model_options["transformer_options"] = {}
|
||||
|
||||
model.model_options["transformer_options"]["coords"] = coords
|
||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
|
||||
model.model_options["transformer_options"]["shape_slat"] = shape_latent
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
|
||||
return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
|
||||
|
||||
|
||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user