Drop cv2 & numpy dependency, run VOIDWarpedNoise with torch.
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

This commit is contained in:
Talmaj Marinc 2026-04-16 21:45:35 +02:00
parent 7609381243
commit b68042b5c4

View File

@ -1,6 +1,5 @@
import logging
import numpy as np
import torch
import comfy
@ -261,8 +260,9 @@ class VOIDWarpedNoise(io.ComfyNode):
latent_h = height // 8
latent_w = width // 8
vid = video[:length].cpu().numpy()
vid_uint8 = (vid * 255).clip(0, 255).astype(np.uint8)
# rp.get_noise_from_video expects uint8 numpy frames; everything
# downstream of the warp stays on torch.
vid_uint8 = (video[:length].clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
frames = [vid_uint8[i] for i in range(vid_uint8.shape[0])]
frames = rp.resize_images_to_hold(frames, height=height, width=width)
@ -285,38 +285,30 @@ class VOIDWarpedNoise(io.ComfyNode):
downscale_factor=round(FRAME * FLOW) * LATENT_SCALE,
)
warped_np = warp_output.numpy_noises # (T, H, W, C)
if warped_np.dtype == np.float16:
warped_np = warped_np.astype(np.float32)
# (T, H, W, C) → torch on intermediate device for torchified resize.
warped = torch.from_numpy(warp_output.numpy_noises).float()
device = comfy.model_management.intermediate_device()
warped = warped.to(device)
import cv2
if warped.shape[0] != latent_t:
indices = torch.linspace(0, warped.shape[0] - 1, latent_t,
device=device).long()
warped = warped[indices]
if warped_np.shape[0] != latent_t:
indices = np.linspace(0, warped_np.shape[0] - 1, latent_t).astype(int)
warped_np = warped_np[indices]
if warped_np.shape[1] != latent_h or warped_np.shape[2] != latent_w:
resized = []
for t_idx in range(latent_t):
frame = warped_np[t_idx]
ch_resized = [
cv2.resize(frame[:, :, c], (latent_w, latent_h),
interpolation=cv2.INTER_LINEAR)
for c in range(frame.shape[2])
]
resized.append(np.stack(ch_resized, axis=2))
warped_np = np.stack(resized, axis=0)
# (T, H, W, C) -> (B, C, T, H, W)
warped_tensor = torch.from_numpy(
warped_np.transpose(3, 0, 1, 2)
).float().unsqueeze(0)
if warped.shape[1] != latent_h or warped.shape[2] != latent_w:
# (T, H, W, C) → (T, C, H, W) → bilinear resize → back
warped = warped.permute(0, 3, 1, 2)
warped = torch.nn.functional.interpolate(
warped, size=(latent_h, latent_w),
mode="bilinear", align_corners=False,
)
warped = warped.permute(0, 2, 3, 1)
# (T, H, W, C) → (B, C, T, H, W)
warped_tensor = warped.permute(3, 0, 1, 2).unsqueeze(0)
if batch_size > 1:
warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1)
warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": warped_tensor})