Fix Trellis2 conditioning and a hidden bug

This commit is contained in:
Yousef Rafat 2026-02-19 22:05:33 +02:00
parent 7454979e16
commit 6191cd86bf
2 changed files with 15 additions and 19 deletions

View File

@ -826,7 +826,12 @@ class Trellis2(nn.Module):
self.guidance_interval_txt = [0.6, 0.9]
def forward(self, x, timestep, context, **kwargs):
# FIXME: should find a way to distinguish between 512/1024 models
# currently assumes 1024
embeds = kwargs.get("embeds")
_, cond = context.chunk(2)
cond = embeds.chunk(2)[0]
context = torch.cat([torch.zeros_like(cond), cond])
mode = kwargs.get("generation_mode")
coords = kwargs.get("coords")
transformer_options = kwargs.get("transformer_options")
@ -837,12 +842,13 @@ class Trellis2(nn.Module):
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
if mode in ["shape_generation", "texture_generation"]:
not_struct_mode = mode in ["shape_generation", "texture_generation"]
if not_struct_mode:
x = SparseTensor(feats=x, coords=coords)
if mode == "shape_generation":
# TODO
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
out = self.img2shape(x, timestep, context)
elif mode == "texture_generation":
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
else: # structure
@ -855,5 +861,6 @@ class Trellis2(nn.Module):
if shape_rule:
out = out.repeat(orig_bsz, 1, 1, 1, 1)
out.generation_mode = mode
if not_struct_mode:
out = out.feats
return out

View File

@ -233,25 +233,14 @@ class Trellis2Conditioning(IO.ComfyNode):
image = image[0]
# TODO
image = Image.fromarray(image.numpy())
image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
image = Image.fromarray(image)
max_size = max(image.size)
scale = min(1, 1024 / max_size)
if scale < 1:
image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
output_np = np.array(image)
alpha = output_np[:, :, 3]
bbox = np.argwhere(alpha > 0.8 * 255)
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
size = int(size * 1)
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
output = image.crop(bbox) # type: ignore
output = np.array(output).astype(np.float32) / 255
output = output[:, :, :3] * output[:, :, 3:4]
image = torch.tensor(output)
image = torch.tensor(np.array(image)).unsqueeze(0)
# could make 1024 an option
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
@ -276,7 +265,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
@classmethod
def execute(cls, structure_output):
decoded = structure_output.data
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
in_channels = 32
latent = torch.randn(coords.shape[0], in_channels)
@ -299,7 +288,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
@classmethod
def execute(cls, structure_output):
# TODO
decoded = structure_output.data
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
in_channels = 32
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])