From fb049693627167d0012627dff85b4f03fe7390bc Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 16 Jun 2026 00:52:00 +0300 Subject: [PATCH] Device fixes --- comfy_extras/nodes_sam3d_body.py | 5 ++--- comfy_extras/sam3d_body/utils.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_sam3d_body.py b/comfy_extras/nodes_sam3d_body.py index cac17a7cb..ff4328de0 100644 --- a/comfy_extras/nodes_sam3d_body.py +++ b/comfy_extras/nodes_sam3d_body.py @@ -908,6 +908,7 @@ class SAM3DBody_Render(io.ComfyNode): return io.NodeOutput(torch.zeros(1, H, W, 3, dtype=torch.float32)) out_device = comfy.model_management.intermediate_device() + out_dtype = comfy.model_management.intermediate_dtype() bg_t = None if background is None else background.to(device=out_device, dtype=torch.float32) if bg_t is not None and tuple(bg_t.shape[1:3]) != (H, W): # Match the background to the render resolution @@ -1037,12 +1038,10 @@ class SAM3DBody_Render(io.ComfyNode): rainbow_tilt_z_deg=rainbow_tilt_z, person_brightness_falloff=person_palette_falloff, ) - frames_out.append(img) + frames_out.append(img.to(device=out_device, dtype=out_dtype)) pbar.update(1) out_image = torch.stack(frames_out, dim=0) - if out_image.device != out_device: - out_image = out_image.to(out_device) return io.NodeOutput(out_image) diff --git a/comfy_extras/sam3d_body/utils.py b/comfy_extras/sam3d_body/utils.py index d46337470..40bdc6a97 100644 --- a/comfy_extras/sam3d_body/utils.py +++ b/comfy_extras/sam3d_body/utils.py @@ -40,16 +40,19 @@ def inputs_from_sam3_track(track_data, B: int, H: int, W: int): packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None if packed is None: return None, None - unpacked = unpack_masks(packed) # (N, K, Hm, Wm) bool - N, K = unpacked.shape[:2] + N, K = packed.shape[0], packed.shape[1] if N != B or K == 0: return None, None + + device = comfy.model_management.get_torch_device() + unpacked = unpack_masks(packed.to(device)) # (N, K, Hm, Wm) bool Hm, Wm = unpacked.shape[2], unpacked.shape[3] resized = F.interpolate( unpacked.float().reshape(N * K, 1, Hm, Wm), size=(H, W), mode="bilinear", align_corners=False, ) - arr = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W).cpu() + arr_gpu = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W) + arr = arr_gpu.cpu() per_frame_masks = [arr[f, :, :, :, None].contiguous() for f in range(N)] full_frame_bbox = torch.tensor([0.0, 0.0, float(W), float(H)], dtype=torch.float32) @@ -57,8 +60,9 @@ def inputs_from_sam3_track(track_data, B: int, H: int, W: int): for f in range(N): derived = [] for k in range(K): - b = _bbox_from_mask(arr[f, k]) - derived.append(b if b is not None else full_frame_bbox) + # Erosion + argmax bbox on GPU; CPU max_pool2d over N*K full-res masks is slow. + b = _bbox_from_mask(arr_gpu[f, k]) + derived.append(b.cpu() if b is not None else full_frame_bbox) per_frame_bboxes.append(torch.stack(derived, dim=0)) return per_frame_bboxes, per_frame_masks