mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
Device fixes
This commit is contained in:
parent
5d643b15e4
commit
fb04969362
@@ -908,6 +908,7 @@ class SAM3DBody_Render(io.ComfyNode):
|
||||
return io.NodeOutput(torch.zeros(1, H, W, 3, dtype=torch.float32))
|
||||
|
||||
out_device = comfy.model_management.intermediate_device()
|
||||
+ out_dtype = comfy.model_management.intermediate_dtype()
|
||||
bg_t = None if background is None else background.to(device=out_device, dtype=torch.float32)
|
||||
|
||||
if bg_t is not None and tuple(bg_t.shape[1:3]) != (H, W): # Match the background to the render resolution
|
||||
@@ -1037,12 +1038,10 @@ class SAM3DBody_Render(io.ComfyNode):
|
||||
rainbow_tilt_z_deg=rainbow_tilt_z,
|
||||
person_brightness_falloff=person_palette_falloff,
|
||||
)
|
||||
- frames_out.append(img)
|
||||
+ frames_out.append(img.to(device=out_device, dtype=out_dtype))
|
||||
pbar.update(1)
|
||||
|
||||
out_image = torch.stack(frames_out, dim=0)
|
||||
- if out_image.device != out_device:
|
||||
- out_image = out_image.to(out_device)
|
||||
return io.NodeOutput(out_image)
|
||||
|
||||
|
||||
|
||||
@@ -40,16 +40,19 @@ def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
|
||||
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
|
||||
if packed is None:
|
||||
return None, None
|
||||
- unpacked = unpack_masks(packed) # (N, K, Hm, Wm) bool
|
||||
- N, K = unpacked.shape[:2]
|
||||
+ N, K = packed.shape[0], packed.shape[1]
|
||||
if N != B or K == 0:
|
||||
return None, None
|
||||
|
||||
+ device = comfy.model_management.get_torch_device()
|
||||
+ unpacked = unpack_masks(packed.to(device)) # (N, K, Hm, Wm) bool
|
||||
Hm, Wm = unpacked.shape[2], unpacked.shape[3]
|
||||
resized = F.interpolate(
|
||||
unpacked.float().reshape(N * K, 1, Hm, Wm),
|
||||
size=(H, W), mode="bilinear", align_corners=False,
|
||||
)
|
||||
- arr = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W).cpu()
|
||||
+ arr_gpu = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W)
|
||||
+ arr = arr_gpu.cpu()
|
||||
|
||||
per_frame_masks = [arr[f, :, :, :, None].contiguous() for f in range(N)]
|
||||
full_frame_bbox = torch.tensor([0.0, 0.0, float(W), float(H)], dtype=torch.float32)
|
||||
@@ -57,8 +60,9 @@ def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
|
||||
for f in range(N):
|
||||
derived = []
|
||||
for k in range(K):
|
||||
- b = _bbox_from_mask(arr[f, k])
|
||||
- derived.append(b if b is not None else full_frame_bbox)
|
||||
+ # Erosion + argmax bbox on GPU; CPU max_pool2d over N*K full-res masks is slow.
|
||||
+ b = _bbox_from_mask(arr_gpu[f, k])
|
||||
+ derived.append(b.cpu() if b is not None else full_frame_bbox)
|
||||
per_frame_bboxes.append(torch.stack(derived, dim=0))
|
||||
return per_frame_bboxes, per_frame_masks
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user