Stretch instead of pad as the model seems to expect that

This commit is contained in:
kijai 2026-05-01 16:10:06 +03:00
parent b243a4d3ae
commit eb0551d28d

View File

@ -463,23 +463,19 @@ class SDPoseKeypointExtractor(io.ComfyNode):
model_h = int(head.heatmap_size[1]) * 4 # 256 * 4 = 1024
def _resize_to_model(imgs):
"""Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
"""Stretch BHWC images to (model_h, model_w), model expects no aspect preservation."""
h, w = imgs.shape[-3], imgs.shape[-2]
scale = min(model_h / h, model_w / w)
sh, sw = int(round(h * scale)), int(round(w * scale))
pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
method = "area" if (model_h <= h and model_w <= w) else "bilinear"
chw = imgs.permute(0, 3, 1, 2).float()
scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
return padded.permute(0, 2, 3, 1), scale, pt, pl
scaled = comfy.utils.common_upscale(chw, model_w, model_h, upscale_method=method, crop="disabled")
return scaled.permute(0, 2, 3, 1), model_w / w, model_h / h
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
def _remap_keypoints(kp, scale_x, scale_y, offset_x=0, offset_y=0):
"""Remap keypoints from model space back to original image space."""
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
invalid = kp[..., 0] < 0
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
kp[..., 0] = kp[..., 0] / scale_x + offset_x
kp[..., 1] = kp[..., 1] / scale_y + offset_y
kp[invalid] = -1
return kp
@ -529,18 +525,18 @@ class SDPoseKeypointExtractor(io.ComfyNode):
continue
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
crop_resized, sx, sy = _resize_to_model(crop)
latent_crop = vae.encode(crop_resized)
kp_batch, sc_batch = _run_on_latent(latent_crop)
kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
kp = _remap_keypoints(kp_batch[0], sx, sy, x1, y1)
img_keypoints.append(kp)
img_scores.append(sc_batch[0])
else:
img_resized, scale, pad_top, pad_left = _resize_to_model(img)
img_resized, sx, sy = _resize_to_model(img)
latent_img = vae.encode(img_resized)
kp_batch, sc_batch = _run_on_latent(latent_img)
img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
img_keypoints.append(_remap_keypoints(kp_batch[0], sx, sy))
img_scores.append(sc_batch[0])
all_keypoints.append(img_keypoints)
@ -549,12 +545,12 @@ class SDPoseKeypointExtractor(io.ComfyNode):
else: # full-image mode, batched
for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
batch_resized, sx, sy = _resize_to_model(image[batch_start:batch_start + batch_size])
latent_batch = vae.encode(batch_resized)
kp_batch, sc_batch = _run_on_latent(latent_batch)
for kp, sc in zip(kp_batch, sc_batch):
all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
all_keypoints.append([_remap_keypoints(kp, sx, sy)])
all_scores.append([sc])
pbar.update(len(kp_batch))