mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
trellis2conditioning and a hidden bug
This commit is contained in:
parent
7454979e16
commit
6191cd86bf
@ -826,7 +826,12 @@ class Trellis2(nn.Module):
|
|||||||
self.guidance_interval_txt = [0.6, 0.9]
|
self.guidance_interval_txt = [0.6, 0.9]
|
||||||
|
|
||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
|
# FIXME: should find a way to distinguish between 512/1024 models
|
||||||
|
# currently assumes 1024
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
|
_, cond = context.chunk(2)
|
||||||
|
cond = embeds.chunk(2)[0]
|
||||||
|
context = torch.cat([torch.zeros_like(cond), cond])
|
||||||
mode = kwargs.get("generation_mode")
|
mode = kwargs.get("generation_mode")
|
||||||
coords = kwargs.get("coords")
|
coords = kwargs.get("coords")
|
||||||
transformer_options = kwargs.get("transformer_options")
|
transformer_options = kwargs.get("transformer_options")
|
||||||
@ -837,12 +842,13 @@ class Trellis2(nn.Module):
|
|||||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||||
|
|
||||||
if mode in ["shape_generation", "texture_generation"]:
|
not_struct_mode = mode in ["shape_generation", "texture_generation"]
|
||||||
|
if not_struct_mode:
|
||||||
x = SparseTensor(feats=x, coords=coords)
|
x = SparseTensor(feats=x, coords=coords)
|
||||||
|
|
||||||
if mode == "shape_generation":
|
if mode == "shape_generation":
|
||||||
# TODO
|
# TODO
|
||||||
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
out = self.img2shape(x, timestep, context)
|
||||||
elif mode == "texture_generation":
|
elif mode == "texture_generation":
|
||||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||||
else: # structure
|
else: # structure
|
||||||
@ -855,5 +861,6 @@ class Trellis2(nn.Module):
|
|||||||
if shape_rule:
|
if shape_rule:
|
||||||
out = out.repeat(orig_bsz, 1, 1, 1, 1)
|
out = out.repeat(orig_bsz, 1, 1, 1, 1)
|
||||||
|
|
||||||
out.generation_mode = mode
|
if not_struct_mode:
|
||||||
|
out = out.feats
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -233,25 +233,14 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
image = image[0]
|
image = image[0]
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
image = Image.fromarray(image.numpy())
|
image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||||
|
image = Image.fromarray(image)
|
||||||
max_size = max(image.size)
|
max_size = max(image.size)
|
||||||
scale = min(1, 1024 / max_size)
|
scale = min(1, 1024 / max_size)
|
||||||
if scale < 1:
|
if scale < 1:
|
||||||
image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
|
image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
output_np = np.array(image)
|
image = torch.tensor(np.array(image)).unsqueeze(0)
|
||||||
alpha = output_np[:, :, 3]
|
|
||||||
bbox = np.argwhere(alpha > 0.8 * 255)
|
|
||||||
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
|
|
||||||
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
|
|
||||||
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
|
|
||||||
size = int(size * 1)
|
|
||||||
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
|
|
||||||
output = image.crop(bbox) # type: ignore
|
|
||||||
output = np.array(output).astype(np.float32) / 255
|
|
||||||
output = output[:, :, :3] * output[:, :, 3:4]
|
|
||||||
|
|
||||||
image = torch.tensor(output)
|
|
||||||
|
|
||||||
# could make 1024 an option
|
# could make 1024 an option
|
||||||
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
||||||
@ -276,7 +265,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output):
|
def execute(cls, structure_output):
|
||||||
decoded = structure_output.data
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = torch.randn(coords.shape[0], in_channels)
|
latent = torch.randn(coords.shape[0], in_channels)
|
||||||
@ -299,7 +288,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output):
|
def execute(cls, structure_output):
|
||||||
# TODO
|
# TODO
|
||||||
decoded = structure_output.data
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user