mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
rewriting conditioning logic + model code addition
This commit is contained in:
parent
92aa058a58
commit
91fa563b21
@ -826,8 +826,11 @@ class Trellis2(nn.Module):
|
|||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
mode = kwargs.get("generation_mode")
|
mode = kwargs.get("generation_mode")
|
||||||
sigmas = kwargs.get("sigmas")[0].item()
|
transformer_options = kwargs.get("transformer_options")
|
||||||
cond = context.chunk(2)
|
sigmas = transformer_options.get("sigmas")[0].item()
|
||||||
|
if sigmas < 1.00001:
|
||||||
|
timestep *= 1000.0
|
||||||
|
cond = context.chunk(2)[1]
|
||||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||||
|
|
||||||
@ -838,6 +841,9 @@ class Trellis2(nn.Module):
|
|||||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||||
else: # structure
|
else: # structure
|
||||||
timestep = timestep_reshift(timestep)
|
timestep = timestep_reshift(timestep)
|
||||||
|
if shape_rule:
|
||||||
|
x = x[0].unsqueeze(0)
|
||||||
|
timestep = timestep[0]
|
||||||
out = self.structure_model(x, timestep, context if not shape_rule else cond)
|
out = self.structure_model(x, timestep, context if not shape_rule else cond)
|
||||||
|
|
||||||
out.generation_mode = mode
|
out.generation_mode = mode
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import torch
|
|||||||
from comfy.ldm.trellis2.model import SparseTensor
|
from comfy.ldm.trellis2.model import SparseTensor
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode
|
|
||||||
|
|
||||||
shape_slat_normalization = {
|
shape_slat_normalization = {
|
||||||
"mean": torch.tensor([
|
"mean": torch.tensor([
|
||||||
@ -36,86 +35,85 @@ tex_slat_normalization = {
|
|||||||
])[None]
|
])[None]
|
||||||
}
|
}
|
||||||
|
|
||||||
def smart_crop_square(
|
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
||||||
image: torch.Tensor,
|
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
||||||
background_color=(128, 128, 128),
|
|
||||||
):
|
def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)):
|
||||||
|
nz = torch.nonzero(mask[0] > 0.5)
|
||||||
|
if nz.shape[0] == 0:
|
||||||
|
C, H, W = image.shape
|
||||||
|
side = max(H, W)
|
||||||
|
canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray
|
||||||
|
canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image
|
||||||
|
return canvas
|
||||||
|
|
||||||
|
y_min, x_min = nz.min(dim=0)[0]
|
||||||
|
y_max, x_max = nz.max(dim=0)[0]
|
||||||
|
|
||||||
|
obj_w, obj_h = x_max - x_min, y_max - y_min
|
||||||
|
center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2
|
||||||
|
|
||||||
|
side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2))
|
||||||
|
half_side = side / 2
|
||||||
|
|
||||||
|
x1, y1 = int(center_x - half_side), int(center_y - half_side)
|
||||||
|
x2, y2 = x1 + side, y1 + side
|
||||||
|
|
||||||
C, H, W = image.shape
|
C, H, W = image.shape
|
||||||
size = max(H, W)
|
canvas = torch.ones((C, side, side), device=image.device)
|
||||||
canvas = torch.empty(
|
|
||||||
(C, size, size),
|
|
||||||
dtype=image.dtype,
|
|
||||||
device=image.device
|
|
||||||
)
|
|
||||||
for c in range(C):
|
for c in range(C):
|
||||||
canvas[c].fill_(background_color[c])
|
canvas[c] *= (bg_color[c] / 255.0)
|
||||||
top = (size - H) // 2
|
|
||||||
left = (size - W) // 2
|
src_x1, src_y1 = max(0, x1), max(0, y1)
|
||||||
canvas[:, top:top + H, left:left + W] = image
|
src_x2, src_y2 = min(W, x2), min(H, y2)
|
||||||
|
|
||||||
|
dst_x1, dst_y1 = max(0, -x1), max(0, -y1)
|
||||||
|
dst_x2 = dst_x1 + (src_x2 - src_x1)
|
||||||
|
dst_y2 = dst_y1 + (src_y2 - src_y1)
|
||||||
|
|
||||||
|
canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = image[:, src_y1:src_y2, src_x1:src_x2]
|
||||||
|
|
||||||
return canvas
|
return canvas
|
||||||
|
|
||||||
def run_conditioning(
|
def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"):
|
||||||
model,
|
model_internal = model.model
|
||||||
image: torch.Tensor,
|
device = comfy.model_management.intermediate_device()
|
||||||
include_1024: bool = True,
|
torch_device = comfy.model_management.get_torch_device()
|
||||||
background_color: str = "black",
|
|
||||||
):
|
|
||||||
# TODO: should check if normalization was applied in these steps
|
|
||||||
model = model.model
|
|
||||||
device = comfy.model_management.intermediate_device() # replaces .cpu()
|
|
||||||
torch_device = comfy.model_management.get_torch_device() # replaces .cuda()
|
|
||||||
bg_colors = {
|
|
||||||
"black": (0, 0, 0),
|
|
||||||
"gray": (128, 128, 128),
|
|
||||||
"white": (255, 255, 255),
|
|
||||||
}
|
|
||||||
bg_color = bg_colors.get(background_color, (128, 128, 128))
|
|
||||||
|
|
||||||
# Convert image to PIL
|
bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)}
|
||||||
if image.dim() == 4:
|
bg_rgb = bg_colors.get(background_color, (128, 128, 128))
|
||||||
pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8)
|
|
||||||
else:
|
|
||||||
pil_image = (image * 255).clip(0, 255).to(torch.uint8)
|
|
||||||
|
|
||||||
pil_image = pil_image.movedim(-1, 0)
|
img_t = image[0].movedim(-1, 0).to(torch_device).float()
|
||||||
pil_image = smart_crop_square(pil_image, background_color=bg_color)
|
mask_t = mask[0].to(torch_device).float()
|
||||||
|
if mask_t.ndim == 2:
|
||||||
|
mask_t = mask_t.unsqueeze(0)
|
||||||
|
|
||||||
model.image_size = 512
|
cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb)
|
||||||
def set_image_size(image, image_size=512):
|
|
||||||
if image.ndim == 3:
|
|
||||||
image = image.unsqueeze(0)
|
|
||||||
|
|
||||||
to_pil = ToPILImage()
|
def prepare_tensor(img, size):
|
||||||
to_tensor = ToTensor()
|
resized = torch.nn.functional.interpolate(
|
||||||
resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS)
|
img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False
|
||||||
|
)
|
||||||
|
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||||
|
|
||||||
pil_img = to_pil(image.squeeze(0))
|
model_internal.image_size = 512
|
||||||
resized_pil = resizer(pil_img)
|
input_512 = prepare_tensor(cropped_img, 512)
|
||||||
image = to_tensor(resized_pil).unsqueeze(0)
|
cond_512 = model_internal(input_512)[0]
|
||||||
|
|
||||||
return image.to(torch_device).float()
|
|
||||||
|
|
||||||
pil_image = set_image_size(pil_image, 512)
|
|
||||||
cond_512 = model(pil_image)[0]
|
|
||||||
|
|
||||||
cond_1024 = None
|
cond_1024 = None
|
||||||
if include_1024:
|
if include_1024:
|
||||||
model.image_size = 1024
|
model_internal.image_size = 1024
|
||||||
pil_image = set_image_size(pil_image, 1024)
|
input_1024 = prepare_tensor(cropped_img, 1024)
|
||||||
cond_1024 = model(pil_image)[0]
|
cond_1024 = model_internal(input_1024)[0]
|
||||||
|
|
||||||
neg_cond = torch.zeros_like(cond_512)
|
|
||||||
|
|
||||||
conditioning = {
|
conditioning = {
|
||||||
'cond_512': cond_512.to(device),
|
'cond_512': cond_512.to(device),
|
||||||
'neg_cond': neg_cond.to(device),
|
'neg_cond': torch.zeros_like(cond_512).to(device),
|
||||||
}
|
}
|
||||||
if cond_1024 is not None:
|
if cond_1024 is not None:
|
||||||
conditioning['cond_1024'] = cond_1024.to(device)
|
conditioning['cond_1024'] = cond_1024.to(device)
|
||||||
|
|
||||||
preprocessed_tensor = pil_image.to(torch.float32) / 255.0
|
preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu()
|
||||||
preprocessed_tensor = preprocessed_tensor.unsqueeze(0)
|
|
||||||
|
|
||||||
return conditioning, preprocessed_tensor
|
return conditioning, preprocessed_tensor
|
||||||
|
|
||||||
@ -213,6 +211,7 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.ClipVision.Input("clip_vision_model"),
|
IO.ClipVision.Input("clip_vision_model"),
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
|
IO.Mask.Input("mask"),
|
||||||
IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
|
IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
@ -222,9 +221,9 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip_vision_model, image, background_color) -> IO.NodeOutput:
|
def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput:
|
||||||
# could make 1024 an option
|
# could make 1024 an option
|
||||||
conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
|
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
||||||
embeds = conditioning["cond_1024"] # should add that
|
embeds = conditioning["cond_1024"] # should add that
|
||||||
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||||
negative = [[conditioning["neg_cond"], {"embeds": embeds}]]
|
negative = [[conditioning["neg_cond"], {"embeds": embeds}]]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user