Device fixes

This commit is contained in:
kijai 2026-06-16 00:52:00 +03:00
parent 5d643b15e4
commit fb04969362
2 changed files with 11 additions and 8 deletions

View File

@@ -908,6 +908,7 @@ class SAM3DBody_Render(io.ComfyNode):
return io.NodeOutput(torch.zeros(1, H, W, 3, dtype=torch.float32))
out_device = comfy.model_management.intermediate_device()
out_dtype = comfy.model_management.intermediate_dtype()
bg_t = None if background is None else background.to(device=out_device, dtype=torch.float32)
if bg_t is not None and tuple(bg_t.shape[1:3]) != (H, W): # Match the background to the render resolution
@@ -1037,12 +1038,10 @@ class SAM3DBody_Render(io.ComfyNode):
rainbow_tilt_z_deg=rainbow_tilt_z,
person_brightness_falloff=person_palette_falloff,
)
frames_out.append(img)
frames_out.append(img.to(device=out_device, dtype=out_dtype))
pbar.update(1)
out_image = torch.stack(frames_out, dim=0)
if out_image.device != out_device:
out_image = out_image.to(out_device)
return io.NodeOutput(out_image)

View File

@@ -40,16 +40,19 @@ def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
if packed is None:
return None, None
unpacked = unpack_masks(packed) # (N, K, Hm, Wm) bool
N, K = unpacked.shape[:2]
N, K = packed.shape[0], packed.shape[1]
if N != B or K == 0:
return None, None
device = comfy.model_management.get_torch_device()
unpacked = unpack_masks(packed.to(device)) # (N, K, Hm, Wm) bool
Hm, Wm = unpacked.shape[2], unpacked.shape[3]
resized = F.interpolate(
unpacked.float().reshape(N * K, 1, Hm, Wm),
size=(H, W), mode="bilinear", align_corners=False,
)
arr = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W).cpu()
arr_gpu = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W)
arr = arr_gpu.cpu()
per_frame_masks = [arr[f, :, :, :, None].contiguous() for f in range(N)]
full_frame_bbox = torch.tensor([0.0, 0.0, float(W), float(H)], dtype=torch.float32)
@@ -57,8 +60,9 @@ def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
for f in range(N):
derived = []
for k in range(K):
b = _bbox_from_mask(arr[f, k])
derived.append(b if b is not None else full_frame_bbox)
# Erosion + argmax bbox on GPU; CPU max_pool2d over N*K full-res masks is slow.
b = _bbox_from_mask(arr_gpu[f, k])
derived.append(b.cpu() if b is not None else full_frame_bbox)
per_frame_bboxes.append(torch.stack(derived, dim=0))
return per_frame_bboxes, per_frame_masks