mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
work on txt gen
This commit is contained in:
parent
def8947e75
commit
56e52e5d03
@ -754,6 +754,7 @@ class Trellis2(nn.Module):
|
|||||||
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
||||||
}
|
}
|
||||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||||
|
self.shape2txt = None
|
||||||
if init_txt_model:
|
if init_txt_model:
|
||||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||||
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
|
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
|
||||||
@ -835,11 +836,12 @@ class Trellis2(nn.Module):
|
|||||||
slat = transformer_options.get("shape_slat")
|
slat = transformer_options.get("shape_slat")
|
||||||
if slat is None:
|
if slat is None:
|
||||||
raise ValueError("shape_slat can't be None")
|
raise ValueError("shape_slat can't be None")
|
||||||
slat.feats = slat.feats.repeat(B, 1)
|
|
||||||
x_st = sparse_cat([x_st, slat])
|
base_slat_feats = slat.feats[:N]
|
||||||
|
slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device)
|
||||||
|
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1))
|
||||||
out = self.shape2txt(x_st, t_eval, c_eval)
|
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||||
else: # structure
|
else: # structure
|
||||||
#timestep = timestep_reshift(timestep)
|
|
||||||
orig_bsz = x.shape[0]
|
orig_bsz = x.shape[0]
|
||||||
if shape_rule:
|
if shape_rule:
|
||||||
x = x[0].unsqueeze(0)
|
x = x[0].unsqueeze(0)
|
||||||
@ -850,8 +852,7 @@ class Trellis2(nn.Module):
|
|||||||
|
|
||||||
if not_struct_mode:
|
if not_struct_mode:
|
||||||
out = out.feats
|
out = out.feats
|
||||||
if not_struct_mode:
|
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
||||||
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
if rule and orig_bsz > 1:
|
||||||
if rule and orig_bsz > 1:
|
out = out.repeat(orig_bsz, 1, 1, 1)
|
||||||
out = out.repeat(orig_bsz, 1, 1, 1)
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -95,7 +95,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
IO.AnyType.Input("shape_subs"),
|
IO.AnyType.Input("shape_subs"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Mesh.Output("mesh"),
|
IO.Voxel.Output("voxel"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -116,11 +116,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
samples = SparseTensor(feats = samples, coords=coords)
|
samples = SparseTensor(feats = samples, coords=coords)
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5
|
voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5
|
||||||
faces = torch.stack([m.faces for m in mesh])
|
voxel = Types.VOXEL(voxel)
|
||||||
verts = torch.stack([m.vertices for m in mesh])
|
return IO.NodeOutput(voxel)
|
||||||
mesh = Types.MESH(vertices=verts, faces=faces)
|
|
||||||
return IO.NodeOutput(mesh)
|
|
||||||
|
|
||||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -377,7 +375,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
node_id="EmptyTextureLatentTrellis2",
|
node_id="EmptyTextureLatentTrellis2",
|
||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Voxel.Input("structure_output"),
|
IO.Voxel.Input("structure_or_coords"),
|
||||||
IO.Latent.Input("shape_latent"),
|
IO.Latent.Input("shape_latent"),
|
||||||
IO.Model.Input("model")
|
IO.Model.Input("model")
|
||||||
],
|
],
|
||||||
@ -388,16 +386,21 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output, shape_latent, model):
|
def execute(cls, structure_or_coords, shape_latent, model):
|
||||||
# TODO
|
channels = 32
|
||||||
decoded = structure_output.data.unsqueeze(1)
|
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
decoded = structure_or_coords.data.unsqueeze(1)
|
||||||
in_channels = 32
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
|
|
||||||
|
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
|
||||||
|
coords = structure_or_coords.int()
|
||||||
|
|
||||||
shape_latent = shape_latent["samples"]
|
shape_latent = shape_latent["samples"]
|
||||||
|
if shape_latent.ndim == 4:
|
||||||
|
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||||
shape_latent = shape_norm(shape_latent, coords)
|
shape_latent = shape_norm(shape_latent, coords)
|
||||||
|
|
||||||
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
latent = torch.randn(1, channels, coords.shape[0], 1)
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options = model.model_options.copy()
|
model.model_options = model.model_options.copy()
|
||||||
if "transformer_options" in model.model_options:
|
if "transformer_options" in model.model_options:
|
||||||
@ -406,9 +409,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
model.model_options["transformer_options"] = {}
|
model.model_options["transformer_options"] = {}
|
||||||
|
|
||||||
model.model_options["transformer_options"]["coords"] = coords
|
model.model_options["transformer_options"]["coords"] = coords
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
|
||||||
model.model_options["transformer_options"]["shape_slat"] = shape_latent
|
model.model_options["transformer_options"]["shape_slat"] = shape_latent
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
|
return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
|
||||||
|
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user