Strip possible input alpha channel

This commit is contained in:
kijai 2026-04-16 13:55:02 +03:00
parent 30a73b3aac
commit 1bb6f20bad

View File

@@ -61,7 +61,7 @@ def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device
if cx2 <= cx1 or cy2 <= cy1:
return _coarse_fallback()
-crop = orig_image_hwc[cy1:cy2, cx1:cx2]
+crop = orig_image_hwc[cy1:cy2, cx1:cx2, :3]
crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
crop_frame = crop_1008.to(device=device, dtype=dtype)
crop_h, crop_w = cy2 - cy1, cx2 - cx1
@@ -115,8 +115,7 @@ class SAM3_Detect(io.ComfyNode):
@classmethod
def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput:
B, H, W, C = image.shape
-image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
+image_in = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
# Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors.
# Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame).
@@ -291,7 +290,7 @@ class SAM3_VideoTrack(io.ComfyNode):
dtype = model.model.get_dtype()
sam3_model = model.model.diffusion_model
-frames = images.movedim(-1, 1)
+frames = images[..., :3].movedim(-1, 1)
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
init_masks = None