mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-25 09:19:46 +08:00
Device fixes
This commit is contained in:
parent
5d643b15e4
commit
fb04969362
@ -908,6 +908,7 @@ class SAM3DBody_Render(io.ComfyNode):
|
|||||||
return io.NodeOutput(torch.zeros(1, H, W, 3, dtype=torch.float32))
|
return io.NodeOutput(torch.zeros(1, H, W, 3, dtype=torch.float32))
|
||||||
|
|
||||||
out_device = comfy.model_management.intermediate_device()
|
out_device = comfy.model_management.intermediate_device()
|
||||||
|
out_dtype = comfy.model_management.intermediate_dtype()
|
||||||
bg_t = None if background is None else background.to(device=out_device, dtype=torch.float32)
|
bg_t = None if background is None else background.to(device=out_device, dtype=torch.float32)
|
||||||
|
|
||||||
if bg_t is not None and tuple(bg_t.shape[1:3]) != (H, W): # Match the background to the render resolution
|
if bg_t is not None and tuple(bg_t.shape[1:3]) != (H, W): # Match the background to the render resolution
|
||||||
@ -1037,12 +1038,10 @@ class SAM3DBody_Render(io.ComfyNode):
|
|||||||
rainbow_tilt_z_deg=rainbow_tilt_z,
|
rainbow_tilt_z_deg=rainbow_tilt_z,
|
||||||
person_brightness_falloff=person_palette_falloff,
|
person_brightness_falloff=person_palette_falloff,
|
||||||
)
|
)
|
||||||
frames_out.append(img)
|
frames_out.append(img.to(device=out_device, dtype=out_dtype))
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
out_image = torch.stack(frames_out, dim=0)
|
out_image = torch.stack(frames_out, dim=0)
|
||||||
if out_image.device != out_device:
|
|
||||||
out_image = out_image.to(out_device)
|
|
||||||
return io.NodeOutput(out_image)
|
return io.NodeOutput(out_image)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -40,16 +40,19 @@ def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
|
|||||||
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
|
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
|
||||||
if packed is None:
|
if packed is None:
|
||||||
return None, None
|
return None, None
|
||||||
unpacked = unpack_masks(packed) # (N, K, Hm, Wm) bool
|
N, K = packed.shape[0], packed.shape[1]
|
||||||
N, K = unpacked.shape[:2]
|
|
||||||
if N != B or K == 0:
|
if N != B or K == 0:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
unpacked = unpack_masks(packed.to(device)) # (N, K, Hm, Wm) bool
|
||||||
Hm, Wm = unpacked.shape[2], unpacked.shape[3]
|
Hm, Wm = unpacked.shape[2], unpacked.shape[3]
|
||||||
resized = F.interpolate(
|
resized = F.interpolate(
|
||||||
unpacked.float().reshape(N * K, 1, Hm, Wm),
|
unpacked.float().reshape(N * K, 1, Hm, Wm),
|
||||||
size=(H, W), mode="bilinear", align_corners=False,
|
size=(H, W), mode="bilinear", align_corners=False,
|
||||||
)
|
)
|
||||||
arr = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W).cpu()
|
arr_gpu = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W)
|
||||||
|
arr = arr_gpu.cpu()
|
||||||
|
|
||||||
per_frame_masks = [arr[f, :, :, :, None].contiguous() for f in range(N)]
|
per_frame_masks = [arr[f, :, :, :, None].contiguous() for f in range(N)]
|
||||||
full_frame_bbox = torch.tensor([0.0, 0.0, float(W), float(H)], dtype=torch.float32)
|
full_frame_bbox = torch.tensor([0.0, 0.0, float(W), float(H)], dtype=torch.float32)
|
||||||
@ -57,8 +60,9 @@ def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
|
|||||||
for f in range(N):
|
for f in range(N):
|
||||||
derived = []
|
derived = []
|
||||||
for k in range(K):
|
for k in range(K):
|
||||||
b = _bbox_from_mask(arr[f, k])
|
# Erosion + argmax bbox on GPU; CPU max_pool2d over N*K full-res masks is slow.
|
||||||
derived.append(b if b is not None else full_frame_bbox)
|
b = _bbox_from_mask(arr_gpu[f, k])
|
||||||
|
derived.append(b.cpu() if b is not None else full_frame_bbox)
|
||||||
per_frame_bboxes.append(torch.stack(derived, dim=0))
|
per_frame_bboxes.append(torch.stack(derived, dim=0))
|
||||||
return per_frame_bboxes, per_frame_masks
|
return per_frame_bboxes, per_frame_masks
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user