mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-21 07:52:39 +08:00
Edge case guards
This commit is contained in:
parent
322b3467d8
commit
b70c9d0959
@ -70,6 +70,8 @@ def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device
|
|||||||
mask_h, mask_w = coarse_mask.shape[-2:]
|
mask_h, mask_w = coarse_mask.shape[-2:]
|
||||||
mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h)
|
mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h)
|
||||||
mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h)
|
mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h)
|
||||||
|
if mx2 <= mx1 or my2 <= my1:
|
||||||
|
return _coarse_fallback()
|
||||||
mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0)
|
mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0)
|
||||||
for _ in range(iterations):
|
for _ in range(iterations):
|
||||||
coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False)
|
coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False)
|
||||||
@ -248,6 +250,8 @@ class SAM3_Detect(io.ComfyNode):
|
|||||||
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
|
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
|
idev = comfy.model_management.intermediate_device()
|
||||||
|
all_masks = [m.to(idev) for m in all_masks]
|
||||||
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
|
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
|
||||||
return io.NodeOutput(mask_out, all_bbox_dicts)
|
return io.NodeOutput(mask_out, all_bbox_dicts)
|
||||||
|
|
||||||
@ -297,7 +301,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
|||||||
pbar = comfy.utils.ProgressBar(N)
|
pbar = comfy.utils.ProgressBar(N)
|
||||||
|
|
||||||
text_prompts = None
|
text_prompts = None
|
||||||
if conditioning is not None:
|
if conditioning is not None and len(conditioning) > 0:
|
||||||
text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)]
|
text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)]
|
||||||
elif initial_mask is None:
|
elif initial_mask is None:
|
||||||
raise ValueError("Either initial_mask or conditioning must be provided")
|
raise ValueError("Either initial_mask or conditioning must be provided")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user