diff --git a/comfy_extras/nodes_void.py b/comfy_extras/nodes_void.py index 4b923f4a2..301163269 100644 --- a/comfy_extras/nodes_void.py +++ b/comfy_extras/nodes_void.py @@ -1,6 +1,5 @@ import logging -import numpy as np import torch import comfy @@ -261,8 +260,9 @@ class VOIDWarpedNoise(io.ComfyNode): latent_h = height // 8 latent_w = width // 8 - vid = video[:length].cpu().numpy() - vid_uint8 = (vid * 255).clip(0, 255).astype(np.uint8) + # rp.get_noise_from_video expects uint8 numpy frames; everything + # downstream of the warp stays on torch. + vid_uint8 = (video[:length].clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() frames = [vid_uint8[i] for i in range(vid_uint8.shape[0])] frames = rp.resize_images_to_hold(frames, height=height, width=width) @@ -285,38 +285,30 @@ class VOIDWarpedNoise(io.ComfyNode): downscale_factor=round(FRAME * FLOW) * LATENT_SCALE, ) - warped_np = warp_output.numpy_noises # (T, H, W, C) - if warped_np.dtype == np.float16: - warped_np = warped_np.astype(np.float32) + # (T, H, W, C) → torch on intermediate device for torchified resize. + warped = torch.from_numpy(warp_output.numpy_noises).float() + device = comfy.model_management.intermediate_device() + warped = warped.to(device) - import cv2 + if warped.shape[0] != latent_t: + indices = torch.linspace(0, warped.shape[0] - 1, latent_t, + device=device).long() + warped = warped[indices] - if warped_np.shape[0] != latent_t: - indices = np.linspace(0, warped_np.shape[0] - 1, latent_t).astype(int) - warped_np = warped_np[indices] - - if warped_np.shape[1] != latent_h or warped_np.shape[2] != latent_w: - resized = [] - for t_idx in range(latent_t): - frame = warped_np[t_idx] - ch_resized = [ - cv2.resize(frame[:, :, c], (latent_w, latent_h), - interpolation=cv2.INTER_LINEAR) - for c in range(frame.shape[2]) - ] - resized.append(np.stack(ch_resized, axis=2)) - warped_np = np.stack(resized, axis=0) - - # (T, H, W, C) -> (B, C, T, H, W) - warped_tensor = torch.from_numpy( - warped_np.transpose(3, 0, 1, 2) - ).float().unsqueeze(0) + if warped.shape[1] != latent_h or warped.shape[2] != latent_w: + # (T, H, W, C) → (T, C, H, W) → bilinear resize → back + warped = warped.permute(0, 3, 1, 2) + warped = torch.nn.functional.interpolate( + warped, size=(latent_h, latent_w), + mode="bilinear", align_corners=False, + ) + warped = warped.permute(0, 2, 3, 1) + # (T, H, W, C) → (B, C, T, H, W) + warped_tensor = warped.permute(3, 0, 1, 2).unsqueeze(0) if batch_size > 1: warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1) - warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples": warped_tensor})