mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
shape working
This commit is contained in:
parent
5d2548822c
commit
def8947e75
@ -770,14 +770,20 @@ class Trellis2(nn.Module):
|
|||||||
is_1024 = self.img2shape.resolution == 1024
|
is_1024 = self.img2shape.resolution == 1024
|
||||||
coords = transformer_options.get("coords", None)
|
coords = transformer_options.get("coords", None)
|
||||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||||
|
is_512_run = False
|
||||||
|
if mode == "shape_generation_512":
|
||||||
|
is_512_run = True
|
||||||
|
mode = "shape_generation"
|
||||||
if coords is not None:
|
if coords is not None:
|
||||||
x = x.squeeze(-1).transpose(1, 2)
|
x = x.squeeze(-1).transpose(1, 2)
|
||||||
not_struct_mode = True
|
not_struct_mode = True
|
||||||
else:
|
else:
|
||||||
mode = "structure_generation"
|
mode = "structure_generation"
|
||||||
not_struct_mode = False
|
not_struct_mode = False
|
||||||
if is_1024 and mode == "shape_generation":
|
|
||||||
|
if is_1024 and mode == "shape_generation" and not is_512_run:
|
||||||
context = embeds
|
context = embeds
|
||||||
|
|
||||||
sigmas = transformer_options.get("sigmas")[0].item()
|
sigmas = transformer_options.get("sigmas")[0].item()
|
||||||
if sigmas < 1.00001:
|
if sigmas < 1.00001:
|
||||||
timestep *= 1000.0
|
timestep *= 1000.0
|
||||||
@ -786,12 +792,24 @@ class Trellis2(nn.Module):
|
|||||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||||
|
|
||||||
if not_struct_mode:
|
if not_struct_mode:
|
||||||
B, N, C = x.shape
|
orig_bsz = x.shape[0]
|
||||||
|
rule = txt_rule if mode == "texture_generation" else shape_rule
|
||||||
|
|
||||||
if mode == "shape_generation":
|
if rule and orig_bsz > 1:
|
||||||
feats_flat = x.reshape(-1, C)
|
x_eval = x[1].unsqueeze(0)
|
||||||
|
t_eval = timestep[1].unsqueeze(0) if timestep.shape[0] > 1 else timestep
|
||||||
|
c_eval = cond
|
||||||
|
else:
|
||||||
|
x_eval = x
|
||||||
|
t_eval = timestep
|
||||||
|
c_eval = context
|
||||||
|
|
||||||
# 3. inflate coords [N, 4] -> [B*N, 4]
|
B, N, C = x_eval.shape
|
||||||
|
|
||||||
|
if mode in ["shape_generation", "texture_generation"]:
|
||||||
|
feats_flat = x_eval.reshape(-1, C)
|
||||||
|
|
||||||
|
# inflate coords [N, 4] -> [B*N, 4]
|
||||||
coords_list = []
|
coords_list = []
|
||||||
for i in range(B):
|
for i in range(B):
|
||||||
c = coords.clone()
|
c = coords.clone()
|
||||||
@ -799,23 +817,27 @@ class Trellis2(nn.Module):
|
|||||||
coords_list.append(c)
|
coords_list.append(c)
|
||||||
|
|
||||||
batched_coords = torch.cat(coords_list, dim=0)
|
batched_coords = torch.cat(coords_list, dim=0)
|
||||||
else: # TODO: texture
|
else:
|
||||||
# may remove the else if texture doesn't require special handling
|
|
||||||
batched_coords = coords
|
batched_coords = coords
|
||||||
feats_flat = x
|
feats_flat = x_eval
|
||||||
x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
|
||||||
|
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
||||||
|
|
||||||
if mode == "shape_generation":
|
if mode == "shape_generation":
|
||||||
# TODO
|
# TODO
|
||||||
out = self.img2shape(x, timestep, context)
|
if is_512_run:
|
||||||
|
out = self.img2shape_512(x_st, t_eval, c_eval)
|
||||||
|
else:
|
||||||
|
out = self.img2shape(x_st, t_eval, c_eval)
|
||||||
elif mode == "texture_generation":
|
elif mode == "texture_generation":
|
||||||
if self.shape2txt is None:
|
if self.shape2txt is None:
|
||||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||||
slat = transformer_options.get("shape_slat")
|
slat = transformer_options.get("shape_slat")
|
||||||
if slat is None:
|
if slat is None:
|
||||||
raise ValueError("shape_slat can't be None")
|
raise ValueError("shape_slat can't be None")
|
||||||
x = sparse_cat([x, slat])
|
slat.feats = slat.feats.repeat(B, 1)
|
||||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
x_st = sparse_cat([x_st, slat])
|
||||||
|
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||||
else: # structure
|
else: # structure
|
||||||
#timestep = timestep_reshift(timestep)
|
#timestep = timestep_reshift(timestep)
|
||||||
orig_bsz = x.shape[0]
|
orig_bsz = x.shape[0]
|
||||||
@ -828,6 +850,8 @@ class Trellis2(nn.Module):
|
|||||||
|
|
||||||
if not_struct_mode:
|
if not_struct_mode:
|
||||||
out = out.feats
|
out = out.feats
|
||||||
if mode == "shape_generation":
|
if not_struct_mode:
|
||||||
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
||||||
|
if rule and orig_bsz > 1:
|
||||||
|
out = out.repeat(orig_bsz, 1, 1, 1)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -178,6 +178,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, shape_latent_512, vae, target_resolution, max_tokens):
|
def execute(cls, shape_latent_512, vae, target_resolution, max_tokens):
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
|
comfy.model_management.load_model_gpu(vae.patcher)
|
||||||
|
|
||||||
feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||||
coords_512 = shape_latent_512["coords"].to(device)
|
coords_512 = shape_latent_512["coords"].to(device)
|
||||||
@ -185,11 +186,9 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
|||||||
slat = shape_norm(feats, coords_512)
|
slat = shape_norm(feats, coords_512)
|
||||||
|
|
||||||
decoder = vae.first_stage_model.shape_dec
|
decoder = vae.first_stage_model.shape_dec
|
||||||
decoder.to(device)
|
|
||||||
|
|
||||||
slat.feats = slat.feats.to(next(decoder.parameters()).dtype)
|
slat.feats = slat.feats.to(next(decoder.parameters()).dtype)
|
||||||
hr_coords = decoder.upsample(slat, upsample_times=4)
|
hr_coords = decoder.upsample(slat, upsample_times=4)
|
||||||
decoder.cpu()
|
|
||||||
|
|
||||||
lr_resolution = 512
|
lr_resolution = 512
|
||||||
hr_resolution = int(target_resolution)
|
hr_resolution = int(target_resolution)
|
||||||
@ -206,7 +205,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
|||||||
break
|
break
|
||||||
hr_resolution -= 128
|
hr_resolution -= 128
|
||||||
|
|
||||||
return IO.NodeOutput(final_coords.cpu())
|
return IO.NodeOutput(final_coords,)
|
||||||
|
|
||||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
||||||
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
||||||
@ -341,11 +340,19 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_or_coords, model):
|
def execute(cls, structure_or_coords, model):
|
||||||
# to accept the upscaled coords
|
# to accept the upscaled coords
|
||||||
if hasattr(structure_or_coords, "data"):
|
is_512_pass = False
|
||||||
|
|
||||||
|
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
|
||||||
decoded = structure_or_coords.data.unsqueeze(1)
|
decoded = structure_or_coords.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
|
is_512_pass = True
|
||||||
|
|
||||||
|
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
|
||||||
|
coords = structure_or_coords.int()
|
||||||
|
is_512_pass = False
|
||||||
|
|
||||||
else:
|
else:
|
||||||
coords = structure_or_coords
|
raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}")
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
# image like format
|
# image like format
|
||||||
latent = torch.randn(1, in_channels, coords.shape[0], 1)
|
latent = torch.randn(1, in_channels, coords.shape[0], 1)
|
||||||
@ -357,7 +364,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
model.model_options["transformer_options"] = {}
|
model.model_options["transformer_options"] = {}
|
||||||
|
|
||||||
model.model_options["transformer_options"]["coords"] = coords
|
model.model_options["transformer_options"]["coords"] = coords
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
if is_512_pass:
|
||||||
|
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
|
||||||
|
else:
|
||||||
|
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||||
return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
|
return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
|
||||||
|
|
||||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user