mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
Edge case guards
This commit is contained in:
parent
322b3467d8
commit
b70c9d0959
@ -70,6 +70,8 @@ def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device
|
||||
mask_h, mask_w = coarse_mask.shape[-2:]
|
||||
mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h)
|
||||
mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h)
|
||||
if mx2 <= mx1 or my2 <= my1:
|
||||
return _coarse_fallback()
|
||||
mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0)
|
||||
for _ in range(iterations):
|
||||
coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False)
|
||||
@ -248,6 +250,8 @@ class SAM3_Detect(io.ComfyNode):
|
||||
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
|
||||
pbar.update(1)
|
||||
|
||||
idev = comfy.model_management.intermediate_device()
|
||||
all_masks = [m.to(idev) for m in all_masks]
|
||||
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
|
||||
return io.NodeOutput(mask_out, all_bbox_dicts)
|
||||
|
||||
@ -297,7 +301,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
||||
pbar = comfy.utils.ProgressBar(N)
|
||||
|
||||
text_prompts = None
|
||||
if conditioning is not None:
|
||||
if conditioning is not None and len(conditioning) > 0:
|
||||
text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)]
|
||||
elif initial_mask is None:
|
||||
raise ValueError("Either initial_mask or conditioning must be provided")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user