Rewrite conditioning logic and add model code

This commit is contained in:
Yousef Rafat 2026-02-16 01:53:53 +02:00
parent 92aa058a58
commit 91fa563b21
2 changed files with 70 additions and 65 deletions

View File

@ -826,8 +826,11 @@ class Trellis2(nn.Module):
def forward(self, x, timestep, context, **kwargs): def forward(self, x, timestep, context, **kwargs):
embeds = kwargs.get("embeds") embeds = kwargs.get("embeds")
mode = kwargs.get("generation_mode") mode = kwargs.get("generation_mode")
sigmas = kwargs.get("sigmas")[0].item() transformer_options = kwargs.get("transformer_options")
cond = context.chunk(2) sigmas = transformer_options.get("sigmas")[0].item()
if sigmas < 1.00001:
timestep *= 1000.0
cond = context.chunk(2)[1]
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
@ -838,6 +841,9 @@ class Trellis2(nn.Module):
out = self.shape2txt(x, timestep, context if not txt_rule else cond) out = self.shape2txt(x, timestep, context if not txt_rule else cond)
else: # structure else: # structure
timestep = timestep_reshift(timestep) timestep = timestep_reshift(timestep)
if shape_rule:
x = x[0].unsqueeze(0)
timestep = timestep[0]
out = self.structure_model(x, timestep, context if not shape_rule else cond) out = self.structure_model(x, timestep, context if not shape_rule else cond)
out.generation_mode = mode out.generation_mode = mode

View File

@ -4,7 +4,6 @@ import torch
from comfy.ldm.trellis2.model import SparseTensor from comfy.ldm.trellis2.model import SparseTensor
import comfy.model_management import comfy.model_management
import comfy.model_patcher import comfy.model_patcher
from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode
shape_slat_normalization = { shape_slat_normalization = {
"mean": torch.tensor([ "mean": torch.tensor([
@ -36,86 +35,85 @@ tex_slat_normalization = {
])[None] ])[None]
} }
def smart_crop_square( dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
image: torch.Tensor, dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
background_color=(128, 128, 128),
): def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)):
nz = torch.nonzero(mask[0] > 0.5)
if nz.shape[0] == 0:
C, H, W = image.shape
side = max(H, W)
canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray
canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image
return canvas
y_min, x_min = nz.min(dim=0)[0]
y_max, x_max = nz.max(dim=0)[0]
obj_w, obj_h = x_max - x_min, y_max - y_min
center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2
side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2))
half_side = side / 2
x1, y1 = int(center_x - half_side), int(center_y - half_side)
x2, y2 = x1 + side, y1 + side
C, H, W = image.shape C, H, W = image.shape
size = max(H, W) canvas = torch.ones((C, side, side), device=image.device)
canvas = torch.empty(
(C, size, size),
dtype=image.dtype,
device=image.device
)
for c in range(C): for c in range(C):
canvas[c].fill_(background_color[c]) canvas[c] *= (bg_color[c] / 255.0)
top = (size - H) // 2
left = (size - W) // 2 src_x1, src_y1 = max(0, x1), max(0, y1)
canvas[:, top:top + H, left:left + W] = image src_x2, src_y2 = min(W, x2), min(H, y2)
dst_x1, dst_y1 = max(0, -x1), max(0, -y1)
dst_x2 = dst_x1 + (src_x2 - src_x1)
dst_y2 = dst_y1 + (src_y2 - src_y1)
canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = image[:, src_y1:src_y2, src_x1:src_x2]
return canvas return canvas
def run_conditioning( def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"):
model, model_internal = model.model
image: torch.Tensor, device = comfy.model_management.intermediate_device()
include_1024: bool = True, torch_device = comfy.model_management.get_torch_device()
background_color: str = "black",
):
# TODO: should check if normalization was applied in these steps
model = model.model
device = comfy.model_management.intermediate_device() # replaces .cpu()
torch_device = comfy.model_management.get_torch_device() # replaces .cuda()
bg_colors = {
"black": (0, 0, 0),
"gray": (128, 128, 128),
"white": (255, 255, 255),
}
bg_color = bg_colors.get(background_color, (128, 128, 128))
# Convert image to PIL bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)}
if image.dim() == 4: bg_rgb = bg_colors.get(background_color, (128, 128, 128))
pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8)
else:
pil_image = (image * 255).clip(0, 255).to(torch.uint8)
pil_image = pil_image.movedim(-1, 0) img_t = image[0].movedim(-1, 0).to(torch_device).float()
pil_image = smart_crop_square(pil_image, background_color=bg_color) mask_t = mask[0].to(torch_device).float()
if mask_t.ndim == 2:
mask_t = mask_t.unsqueeze(0)
model.image_size = 512 cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb)
def set_image_size(image, image_size=512):
if image.ndim == 3:
image = image.unsqueeze(0)
to_pil = ToPILImage() def prepare_tensor(img, size):
to_tensor = ToTensor() resized = torch.nn.functional.interpolate(
resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS) img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False
)
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
pil_img = to_pil(image.squeeze(0)) model_internal.image_size = 512
resized_pil = resizer(pil_img) input_512 = prepare_tensor(cropped_img, 512)
image = to_tensor(resized_pil).unsqueeze(0) cond_512 = model_internal(input_512)[0]
return image.to(torch_device).float()
pil_image = set_image_size(pil_image, 512)
cond_512 = model(pil_image)[0]
cond_1024 = None cond_1024 = None
if include_1024: if include_1024:
model.image_size = 1024 model_internal.image_size = 1024
pil_image = set_image_size(pil_image, 1024) input_1024 = prepare_tensor(cropped_img, 1024)
cond_1024 = model(pil_image)[0] cond_1024 = model_internal(input_1024)[0]
neg_cond = torch.zeros_like(cond_512)
conditioning = { conditioning = {
'cond_512': cond_512.to(device), 'cond_512': cond_512.to(device),
'neg_cond': neg_cond.to(device), 'neg_cond': torch.zeros_like(cond_512).to(device),
} }
if cond_1024 is not None: if cond_1024 is not None:
conditioning['cond_1024'] = cond_1024.to(device) conditioning['cond_1024'] = cond_1024.to(device)
preprocessed_tensor = pil_image.to(torch.float32) / 255.0 preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu()
preprocessed_tensor = preprocessed_tensor.unsqueeze(0)
return conditioning, preprocessed_tensor return conditioning, preprocessed_tensor
@ -213,6 +211,7 @@ class Trellis2Conditioning(IO.ComfyNode):
inputs=[ inputs=[
IO.ClipVision.Input("clip_vision_model"), IO.ClipVision.Input("clip_vision_model"),
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Mask.Input("mask"),
IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black") IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
], ],
outputs=[ outputs=[
@ -222,9 +221,9 @@ class Trellis2Conditioning(IO.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, clip_vision_model, image, background_color) -> IO.NodeOutput: def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput:
# could make 1024 an option # could make 1024 an option
conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
embeds = conditioning["cond_1024"] # should add that embeds = conditioning["cond_1024"] # should add that
positive = [[conditioning["cond_512"], {"embeds": embeds}]] positive = [[conditioning["cond_512"], {"embeds": embeds}]]
negative = [[conditioning["neg_cond"], {"embeds": embeds}]] negative = [[conditioning["neg_cond"], {"embeds": embeds}]]