mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
rewriting conditioning logic + model code addition
This commit is contained in:
parent
92aa058a58
commit
91fa563b21
@ -826,8 +826,11 @@ class Trellis2(nn.Module):
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
embeds = kwargs.get("embeds")
|
||||
mode = kwargs.get("generation_mode")
|
||||
sigmas = kwargs.get("sigmas")[0].item()
|
||||
cond = context.chunk(2)
|
||||
transformer_options = kwargs.get("transformer_options")
|
||||
sigmas = transformer_options.get("sigmas")[0].item()
|
||||
if sigmas < 1.00001:
|
||||
timestep *= 1000.0
|
||||
cond = context.chunk(2)[1]
|
||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||
|
||||
@ -838,6 +841,9 @@ class Trellis2(nn.Module):
|
||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||
else: # structure
|
||||
timestep = timestep_reshift(timestep)
|
||||
if shape_rule:
|
||||
x = x[0].unsqueeze(0)
|
||||
timestep = timestep[0]
|
||||
out = self.structure_model(x, timestep, context if not shape_rule else cond)
|
||||
|
||||
out.generation_mode = mode
|
||||
|
||||
@ -4,7 +4,6 @@ import torch
|
||||
from comfy.ldm.trellis2.model import SparseTensor
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode
|
||||
|
||||
shape_slat_normalization = {
|
||||
"mean": torch.tensor([
|
||||
@ -36,86 +35,85 @@ tex_slat_normalization = {
|
||||
])[None]
|
||||
}
|
||||
|
||||
def smart_crop_square(
|
||||
image: torch.Tensor,
|
||||
background_color=(128, 128, 128),
|
||||
):
|
||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
||||
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
||||
|
||||
def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)):
|
||||
nz = torch.nonzero(mask[0] > 0.5)
|
||||
if nz.shape[0] == 0:
|
||||
C, H, W = image.shape
|
||||
side = max(H, W)
|
||||
canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray
|
||||
canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image
|
||||
return canvas
|
||||
|
||||
y_min, x_min = nz.min(dim=0)[0]
|
||||
y_max, x_max = nz.max(dim=0)[0]
|
||||
|
||||
obj_w, obj_h = x_max - x_min, y_max - y_min
|
||||
center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2
|
||||
|
||||
side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2))
|
||||
half_side = side / 2
|
||||
|
||||
x1, y1 = int(center_x - half_side), int(center_y - half_side)
|
||||
x2, y2 = x1 + side, y1 + side
|
||||
|
||||
C, H, W = image.shape
|
||||
size = max(H, W)
|
||||
canvas = torch.empty(
|
||||
(C, size, size),
|
||||
dtype=image.dtype,
|
||||
device=image.device
|
||||
)
|
||||
canvas = torch.ones((C, side, side), device=image.device)
|
||||
for c in range(C):
|
||||
canvas[c].fill_(background_color[c])
|
||||
top = (size - H) // 2
|
||||
left = (size - W) // 2
|
||||
canvas[:, top:top + H, left:left + W] = image
|
||||
canvas[c] *= (bg_color[c] / 255.0)
|
||||
|
||||
src_x1, src_y1 = max(0, x1), max(0, y1)
|
||||
src_x2, src_y2 = min(W, x2), min(H, y2)
|
||||
|
||||
dst_x1, dst_y1 = max(0, -x1), max(0, -y1)
|
||||
dst_x2 = dst_x1 + (src_x2 - src_x1)
|
||||
dst_y2 = dst_y1 + (src_y2 - src_y1)
|
||||
|
||||
canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = image[:, src_y1:src_y2, src_x1:src_x2]
|
||||
|
||||
return canvas
|
||||
|
||||
def run_conditioning(
|
||||
model,
|
||||
image: torch.Tensor,
|
||||
include_1024: bool = True,
|
||||
background_color: str = "black",
|
||||
):
|
||||
# TODO: should check if normalization was applied in these steps
|
||||
model = model.model
|
||||
device = comfy.model_management.intermediate_device() # replaces .cpu()
|
||||
torch_device = comfy.model_management.get_torch_device() # replaces .cuda()
|
||||
bg_colors = {
|
||||
"black": (0, 0, 0),
|
||||
"gray": (128, 128, 128),
|
||||
"white": (255, 255, 255),
|
||||
}
|
||||
bg_color = bg_colors.get(background_color, (128, 128, 128))
|
||||
def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"):
|
||||
model_internal = model.model
|
||||
device = comfy.model_management.intermediate_device()
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
|
||||
# Convert image to PIL
|
||||
if image.dim() == 4:
|
||||
pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8)
|
||||
else:
|
||||
pil_image = (image * 255).clip(0, 255).to(torch.uint8)
|
||||
bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)}
|
||||
bg_rgb = bg_colors.get(background_color, (128, 128, 128))
|
||||
|
||||
pil_image = pil_image.movedim(-1, 0)
|
||||
pil_image = smart_crop_square(pil_image, background_color=bg_color)
|
||||
img_t = image[0].movedim(-1, 0).to(torch_device).float()
|
||||
mask_t = mask[0].to(torch_device).float()
|
||||
if mask_t.ndim == 2:
|
||||
mask_t = mask_t.unsqueeze(0)
|
||||
|
||||
model.image_size = 512
|
||||
def set_image_size(image, image_size=512):
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb)
|
||||
|
||||
to_pil = ToPILImage()
|
||||
to_tensor = ToTensor()
|
||||
resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS)
|
||||
def prepare_tensor(img, size):
|
||||
resized = torch.nn.functional.interpolate(
|
||||
img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False
|
||||
)
|
||||
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||
|
||||
pil_img = to_pil(image.squeeze(0))
|
||||
resized_pil = resizer(pil_img)
|
||||
image = to_tensor(resized_pil).unsqueeze(0)
|
||||
|
||||
return image.to(torch_device).float()
|
||||
|
||||
pil_image = set_image_size(pil_image, 512)
|
||||
cond_512 = model(pil_image)[0]
|
||||
model_internal.image_size = 512
|
||||
input_512 = prepare_tensor(cropped_img, 512)
|
||||
cond_512 = model_internal(input_512)[0]
|
||||
|
||||
cond_1024 = None
|
||||
if include_1024:
|
||||
model.image_size = 1024
|
||||
pil_image = set_image_size(pil_image, 1024)
|
||||
cond_1024 = model(pil_image)[0]
|
||||
|
||||
neg_cond = torch.zeros_like(cond_512)
|
||||
model_internal.image_size = 1024
|
||||
input_1024 = prepare_tensor(cropped_img, 1024)
|
||||
cond_1024 = model_internal(input_1024)[0]
|
||||
|
||||
conditioning = {
|
||||
'cond_512': cond_512.to(device),
|
||||
'neg_cond': neg_cond.to(device),
|
||||
'neg_cond': torch.zeros_like(cond_512).to(device),
|
||||
}
|
||||
if cond_1024 is not None:
|
||||
conditioning['cond_1024'] = cond_1024.to(device)
|
||||
|
||||
preprocessed_tensor = pil_image.to(torch.float32) / 255.0
|
||||
preprocessed_tensor = preprocessed_tensor.unsqueeze(0)
|
||||
preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu()
|
||||
|
||||
return conditioning, preprocessed_tensor
|
||||
|
||||
@ -213,6 +211,7 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.ClipVision.Input("clip_vision_model"),
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
|
||||
],
|
||||
outputs=[
|
||||
@ -222,9 +221,9 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_model, image, background_color) -> IO.NodeOutput:
|
||||
def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput:
|
||||
# could make 1024 an option
|
||||
conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
|
||||
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
||||
embeds = conditioning["cond_1024"] # should add that
|
||||
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||
negative = [[conditioning["neg_cond"], {"embeds": embeds}]]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user