mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
trellis2conditioning and a hidden bug
This commit is contained in:
parent
7454979e16
commit
6191cd86bf
@ -826,7 +826,12 @@ class Trellis2(nn.Module):
|
||||
self.guidance_interval_txt = [0.6, 0.9]
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
# FIXME: should find a way to distinguish between 512/1024 models
|
||||
# currently assumes 1024
|
||||
embeds = kwargs.get("embeds")
|
||||
_, cond = context.chunk(2)
|
||||
cond = embeds.chunk(2)[0]
|
||||
context = torch.cat([torch.zeros_like(cond), cond])
|
||||
mode = kwargs.get("generation_mode")
|
||||
coords = kwargs.get("coords")
|
||||
transformer_options = kwargs.get("transformer_options")
|
||||
@ -837,12 +842,13 @@ class Trellis2(nn.Module):
|
||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||
|
||||
if mode in ["shape_generation", "texture_generation"]:
|
||||
not_struct_mode = mode in ["shape_generation", "texture_generation"]
|
||||
if not_struct_mode:
|
||||
x = SparseTensor(feats=x, coords=coords)
|
||||
|
||||
if mode == "shape_generation":
|
||||
# TODO
|
||||
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
||||
out = self.img2shape(x, timestep, context)
|
||||
elif mode == "texture_generation":
|
||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||
else: # structure
|
||||
@ -855,5 +861,6 @@ class Trellis2(nn.Module):
|
||||
if shape_rule:
|
||||
out = out.repeat(orig_bsz, 1, 1, 1, 1)
|
||||
|
||||
out.generation_mode = mode
|
||||
if not_struct_mode:
|
||||
out = out.feats
|
||||
return out
|
||||
|
||||
@ -233,25 +233,14 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
image = image[0]
|
||||
|
||||
# TODO
|
||||
image = Image.fromarray(image.numpy())
|
||||
image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
image = Image.fromarray(image)
|
||||
max_size = max(image.size)
|
||||
scale = min(1, 1024 / max_size)
|
||||
if scale < 1:
|
||||
image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
|
||||
|
||||
output_np = np.array(image)
|
||||
alpha = output_np[:, :, 3]
|
||||
bbox = np.argwhere(alpha > 0.8 * 255)
|
||||
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
|
||||
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
|
||||
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
|
||||
size = int(size * 1)
|
||||
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
|
||||
output = image.crop(bbox) # type: ignore
|
||||
output = np.array(output).astype(np.float32) / 255
|
||||
output = output[:, :, :3] * output[:, :, 3:4]
|
||||
|
||||
image = torch.tensor(output)
|
||||
image = torch.tensor(np.array(image)).unsqueeze(0)
|
||||
|
||||
# could make 1024 an option
|
||||
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
||||
@ -276,7 +265,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, structure_output):
|
||||
decoded = structure_output.data
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
in_channels = 32
|
||||
latent = torch.randn(coords.shape[0], in_channels)
|
||||
@ -299,7 +288,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, structure_output):
|
||||
# TODO
|
||||
decoded = structure_output.data
|
||||
decoded = structure_output.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
in_channels = 32
|
||||
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user