mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
initial RIFE support
This commit is contained in:
parent
0c63b4f6e3
commit
a859152817
183
comfy_extras/nodes_frame_interpolation.py
Normal file
183
comfy_extras/nodes_frame_interpolation.py
Normal file
@ -0,0 +1,183 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy import model_management
|
||||
from comfy_extras.rife_model.ifnet import IFNet, detect_rife_config, clear_warp_cache
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
# Custom socket type carrying the loaded interpolation model: produced by
# FrameInterpolationModelLoader and consumed by FrameInterpolate.
FrameInterpolationModel = io.Custom("FRAME_INTERPOLATION_MODEL")
|
||||
|
||||
|
||||
class FrameInterpolationModelLoader(io.ComfyNode):
    """Node that loads a RIFE checkpoint and returns it wrapped in a ModelPatcher."""

    @classmethod
    def define_schema(cls):
        """Declare one combo input (checkpoint file name) and one model output."""
        checkpoint_choices = folder_paths.get_filename_list("frame_interpolation")
        return io.Schema(
            node_id="FrameInterpolationModelLoader",
            display_name="Load Frame Interpolation Model",
            category="loaders",
            inputs=[io.Combo.Input("model_name", options=checkpoint_choices)],
            outputs=[FrameInterpolationModel.Output()],
        )

    @classmethod
    def execute(cls, model_name) -> io.NodeOutput:
        """Load ``model_name``, normalise its key names and build the IFNet.

        The returned NodeOutput holds a ModelPatcher around the network, which
        has been put in eval mode and cast to fp16 or fp32.
        """
        checkpoint_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name)
        weights = comfy.utils.load_torch_file(checkpoint_path, safe_load=True)

        # Drop wrapper prefixes (DataParallel "module.", RIFE "flownet.").
        weights = comfy.utils.state_dict_prefix_replace(weights, {"module.": "", "flownet.": ""})

        # Checkpoints that name blocks "blockN.xxx" are renamed to the
        # "blocks.N.xxx" layout used by the ModuleList in IFNet.
        block_prefixes = [(f"block{n}.", f"blocks.{n}.") for n in range(5)]
        renamed = {}
        for name in weights:
            for old_prefix, new_prefix in block_prefixes:
                if name.startswith(old_prefix):
                    renamed[name] = new_prefix + name[len(old_prefix):]
        if renamed:
            weights = {renamed.get(name, name): tensor for name, tensor in weights.items()}

        # Training-only entries (teacher distillation, timestamp calibration)
        # have no counterpart in the inference network.
        training_only = ("teacher.", "caltime.")
        weights = {
            name: tensor
            for name, tensor in weights.items()
            if not name.startswith(training_only)
        }

        head_ch, channels = detect_rife_config(weights)
        net = IFNet(head_ch=head_ch, channels=channels)
        net.load_state_dict(weights)

        device = model_management.get_torch_device()
        # RIFE is a small pixel-space model similar to a VAE; bf16 produces
        # artifacts due to low mantissa precision, so only fp16/fp32 are allowed.
        dtype = model_management.vae_dtype(
            device=device,
            allowed_dtypes=[torch.float16, torch.float32],
        )
        net.eval().to(dtype)

        patcher = comfy.model_patcher.ModelPatcher(
            net,
            load_device=device,
            offload_device=model_management.unet_offload_device(),
        )
        return io.NodeOutput(patcher)
|
||||
|
||||
|
||||
class FrameInterpolate(io.ComfyNode):
    """Node that inserts interpolated frames between consecutive images of a batch."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs (model, images, multiplier) and image output."""
        return io.Schema(
            node_id="FrameInterpolate",
            display_name="Frame Interpolate",
            category="image/video",
            search_aliases=["rife", "frame interpolation", "slow motion", "interpolate frames"],
            inputs=[
                FrameInterpolationModel.Input("model"),
                io.Image.Input("images"),
                io.Int.Input("multiplier", default=2, min=2, max=16),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, images, multiplier) -> io.NodeOutput:
        """Interpolate ``multiplier - 1`` frames between every pair of input frames.

        Args:
            model: ModelPatcher wrapping the interpolation network.
            images: Image batch laid out as (frames, H, W, C).
            multiplier: Output frames per input interval (2 doubles the rate).

        Returns:
            NodeOutput with ``(frames - 1) * multiplier + 1`` images, or the
            input unchanged when there are fewer than two frames or
            ``multiplier < 2``.

        Raises:
            model_management.OOM_EXCEPTION: if inference still runs out of
                memory after the batch size has been reduced to 1.
        """
        offload_device = model_management.intermediate_device()

        num_frames = images.shape[0]
        if num_frames < 2 or multiplier < 2:
            # Nothing to interpolate between; pass the input through.
            return io.NodeOutput(images)

        model_management.load_model_gpu(model)
        device = model.load_device
        inference_model = model.model
        dtype = model.model_dtype()

        # BHWC -> BCHW
        frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device)
        _, C, H, W = frames.shape

        # Pad bottom/right to a multiple of 64.
        pad_h = (64 - H % 64) % 64
        pad_w = (64 - W % 64) % 64
        if pad_h > 0 or pad_w > 0:
            # Reflect padding requires the pad to be smaller than the padded
            # dimension; fall back to edge replication for very small frames
            # instead of raising.
            pad_mode = "reflect" if pad_h < H and pad_w < W else "replicate"
            frames = F.pad(frames, (0, pad_w, 0, pad_h), mode=pad_mode)
        _, _, pH, pW = frames.shape

        # Total interpolation passes, used for both progress bars.
        total_pairs = num_frames - 1
        num_interp = multiplier - 1
        total_steps = total_pairs * num_interp

        # Start with all in-between timesteps of a pair in one batch; the
        # batch is halved on out-of-memory errors below.
        batch = num_interp
        t_values = [t / multiplier for t in range(1, multiplier)]

        # Pre-allocate the output tensor.
        total_out_frames = total_pairs * multiplier + 1
        result = torch.empty((total_out_frames, C, pH, pW), dtype=dtype, device=offload_device)

        pbar = comfy.utils.ProgressBar(total_steps)
        tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation")
        use_pin = False
        try:
            # Pin inside the try block so that any failure during the
            # remaining setup still reaches the unpin in `finally`.
            use_pin = model_management.pin_memory(result)
            result[0] = frames[0]
            out_idx = 1

            # Per-timestep planes, built once on the compute device.
            ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1)
            ts_full = ts_full.expand(-1, 1, pH, pW)

            # The second frame of one pair is the first frame of the next, so
            # each frame is moved to the compute device only once.
            img1_single = frames[0:1].to(device)
            for i in range(total_pairs):
                img0_single = img1_single
                img1_single = frames[i + 1:i + 2].to(device)

                j = 0
                while j < num_interp:
                    b = min(batch, num_interp - j)
                    try:
                        img0 = img0_single.expand(b, -1, -1, -1)
                        img1 = img1_single.expand(b, -1, -1, -1)
                        mids = inference_model(img0, img1, timestep=ts_full[j:j + b])
                        result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin)
                        out_idx += b
                        pbar.update(b)
                        tqdm_bar.update(b)
                        j += b
                    except model_management.OOM_EXCEPTION:
                        if batch <= 1:
                            raise
                        # Retry the same timesteps with half the batch.
                        batch = max(1, batch // 2)
                        model_management.soft_empty_cache()

                # The original second frame of the pair follows its in-betweens.
                result[out_idx].copy_(frames[i + 1])
                out_idx += 1
        finally:
            tqdm_bar.close()
            clear_warp_cache()
            if use_pin:
                # Wait for outstanding non-blocking copies before unpinning.
                model_management.synchronize()
                model_management.unpin_memory(result)

        # Crop padding and BCHW -> BHWC
        if pad_h > 0 or pad_w > 0:
            result = result[:, :, :H, :W]
        result = result.movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype())
        return io.NodeOutput(result)
|
||||
|
||||
|
||||
class FrameInterpolationExtension(ComfyExtension):
    """Extension that exposes the frame-interpolation nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [FrameInterpolationModelLoader, FrameInterpolate]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> FrameInterpolationExtension:
    """Create and return this module's extension instance."""
    extension = FrameInterpolationExtension()
    return extension
|
||||
168
comfy_extras/rife_model/ifnet.py
Normal file
168
comfy_extras/rife_model/ifnet.py
Normal file
@ -0,0 +1,168 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
|
||||
# Default layer namespace (Conv2d, ConvTranspose2d, ...) used by the modules
# below whenever the caller does not pass its own `operations`.
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# Base sampling grids built by warp(), keyed by (height, width, device).
_warp_grid_cache = {}
|
||||
# Maximum number of cached grids; warp() evicts the oldest entry beyond this.
_WARP_GRID_CACHE_MAX = 4
|
||||
|
||||
|
||||
def clear_warp_cache():
    """Empty the module-level grid cache used by :func:`warp`."""
    # Clear in place so every reference to the dict sees the empty cache.
    _warp_grid_cache.clear()
|
||||
|
||||
|
||||
def warp(img, flow):
    """Resample ``img`` at positions displaced by ``flow``.

    Args:
        img: Tensor of shape (B, C, H, W).
        flow: Displacement tensor; channel 0 is scaled by the width and
            channel 1 by the height (presumably pixel units, x then y).

    Returns:
        The bilinearly sampled image, cast back to the dtype ``img`` had on
        entry. Sampling itself is done in float32.
    """
    batch, _, height, width = img.shape
    out_dtype = img.dtype
    img = img.float()
    flow = flow.float()

    key = (height, width, flow.device)
    if key not in _warp_grid_cache:
        if len(_warp_grid_cache) >= _WARP_GRID_CACHE_MAX:
            # Dicts keep insertion order, so the first key is the oldest entry.
            oldest = next(iter(_warp_grid_cache))
            del _warp_grid_cache[oldest]
        rows = torch.linspace(-1.0, 1.0, height, device=flow.device, dtype=torch.float32)
        cols = torch.linspace(-1.0, 1.0, width, device=flow.device, dtype=torch.float32)
        grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
        # Stored as (1, 2, H, W) with x first, the order grid_sample expects.
        _warp_grid_cache[key] = torch.stack((grid_x, grid_y), dim=0).unsqueeze(0)
    base_grid = _warp_grid_cache[key].expand(batch, -1, -1, -1)

    # Rescale the flow to the [-1, 1] coordinate range of the base grid.
    flow_x = flow[:, 0:1] / ((width - 1) / 2)
    flow_y = flow[:, 1:2] / ((height - 1) / 2)
    flow_norm = torch.cat([flow_x, flow_y], dim=1)

    sample_grid = (base_grid + flow_norm).permute(0, 2, 3, 1)
    sampled = F.grid_sample(img, sample_grid, mode="bilinear", padding_mode="border", align_corners=True)
    return sampled.to(out_dtype)
|
||||
|
||||
|
||||
class Head(nn.Module):
    """Feature encoder: a stride-2 conv, two convs, then a stride-2 transposed conv.

    Takes 3-channel input and produces ``out_ch`` feature channels.
    """

    def __init__(self, out_ch=4, device=None, dtype=None, operations=ops):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        self.cnn0 = operations.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, **factory)
        self.cnn1 = operations.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, **factory)
        self.cnn2 = operations.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, **factory)
        self.cnn3 = operations.ConvTranspose2d(16, out_ch, kernel_size=4, stride=2, padding=1, **factory)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Encode ``x``; the final transposed conv has no activation."""
        for conv in (self.cnn0, self.cnn1, self.cnn2):
            x = self.relu(conv(x))
        return self.cnn3(x)
|
||||
|
||||
|
||||
class ResConv(nn.Module):
    """Residual 3x3 conv whose branch is scaled per channel by a learned ``beta``."""

    def __init__(self, c, device=None, dtype=None, operations=ops):
        super().__init__()
        self.conv = operations.Conv2d(c, c, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype)
        beta_init = torch.ones((1, c, 1, 1), device=device, dtype=dtype)
        self.beta = nn.Parameter(beta_init)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``leaky_relu(x + conv(x) * beta)``."""
        residual = self.conv(x)
        return self.relu(torch.addcmul(x, residual, self.beta))
|
||||
|
||||
|
||||
class IFBlock(nn.Module):
    """One refinement stage producing a 4-channel flow, a 1-channel mask and 8 feature channels."""

    def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops):
        super().__init__()

        def down(cin, cout):
            # Stride-2 conv followed by an in-place LeakyReLU.
            return nn.Sequential(
                operations.Conv2d(cin, cout, kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            )

        # Two stride-2 stages: in_planes -> c // 2 -> c.
        self.conv0 = nn.Sequential(down(in_planes, c // 2), down(c // 2, c))
        self.convblock = nn.Sequential(
            *[ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8)]
        )
        # 4 * 13 channels become 13 after PixelShuffle(2): 4 flow + 1 mask + 8 features.
        self.lastconv = nn.Sequential(
            operations.ConvTranspose2d(c, 4 * 13, kernel_size=4, stride=2, padding=1, device=device, dtype=dtype),
            nn.PixelShuffle(2),
        )

    def forward(self, x, flow=None, scale=1):
        """Run the block at ``1 / scale`` resolution and return ``(flow, mask, feat)``.

        Args:
            x: Input feature stack at full resolution.
            flow: Optional current flow estimate; it is resized, divided by
                ``scale`` and concatenated to ``x``.
            scale: Downscale factor applied before the convs and undone after.
        """
        shrink = 1.0 / scale
        x = F.interpolate(x, scale_factor=shrink, mode="bilinear")
        if flow is not None:
            # Flow magnitudes shrink together with the resolution.
            flow = F.interpolate(flow, scale_factor=shrink, mode="bilinear").div_(scale)
            x = torch.cat((x, flow), 1)

        out = self.lastconv(self.convblock(self.conv0(x)))
        out = F.interpolate(out, scale_factor=scale, mode="bilinear")

        # Bring the predicted flow back to full-resolution magnitudes.
        return out[:, :4] * scale, out[:, 4:5], out[:, 5:]
|
||||
|
||||
|
||||
class IFNet(nn.Module):
    """Five-stage coarse-to-fine network that blends two warped frames into an intermediate one."""

    def __init__(self, head_ch=4, channels=None, device=None, dtype=None, operations=ops):
        super().__init__()
        if channels is None:
            channels = [192, 128, 96, 64, 32]
        self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations)

        # First stage input: two RGB frames + timestep (7) + both encodings.
        # Later stages additionally receive mask (1), features (8) and flow (4).
        first_in = 7 + 2 * head_ch
        later_in = 8 + 4 + 8 + 2 * head_ch
        block_in = [first_in] + [later_in] * 4
        self.blocks = nn.ModuleList([
            IFBlock(block_in[idx], channels[idx], device=device, dtype=dtype, operations=operations)
            for idx in range(5)
        ])
        self.scale_list = [16, 8, 4, 2, 1]

    def get_dtype(self):
        """Return the dtype of the first encoder conv's weight."""
        return self.encode.cnn0.weight.dtype

    def forward(self, img0, img1, timestep=0.5):
        """Return the frame at ``timestep`` between ``img0`` and ``img1``.

        Args:
            img0: First frame batch (B, 3, H, W); clamped to [0, 1].
            img1: Second frame batch, same shape; clamped to [0, 1].
            timestep: Either a number, which is broadcast to a (B, 1, H, W)
                plane, or a tensor that is used as given.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        if not isinstance(timestep, torch.Tensor):
            plane_shape = (img0.shape[0], 1, img0.shape[2], img0.shape[3])
            timestep = torch.full(plane_shape, timestep, device=img0.device, dtype=img0.dtype)

        f0 = self.encode(img0)
        f1 = self.encode(img1)

        flow = mask = feat = None
        warped_img0, warped_img1 = img0, img1
        for block, scale in zip(self.blocks, self.scale_list):
            if flow is None:
                # First stage: no flow estimate yet, use the raw inputs.
                stage_in = torch.cat((img0, img1, f0, f1, timestep), 1)
                flow, mask, feat = block(stage_in, None, scale=scale)
            else:
                # Later stages predict a residual on top of the current flow.
                wf0 = warp(f0, flow[:, :2])
                wf1 = warp(f1, flow[:, 2:4])
                stage_in = torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask, feat), 1)
                fd, mask, feat = block(stage_in, flow, scale=scale)
                flow = flow.add_(fd)
            # Channels 0-1 warp img0, channels 2-3 warp img1.
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])

        blend = torch.sigmoid(mask)
        return torch.lerp(warped_img1, warped_img0, blend)
|
||||
|
||||
|
||||
def detect_rife_config(state_dict):
    """Infer the IFNet constructor arguments from a checkpoint's weight shapes.

    Args:
        state_dict: Mapping of parameter names (already normalised to the
            ``encode.*`` / ``blocks.N.*`` layout) to objects with a ``.shape``.

    Returns:
        ``(head_ch, channels)``: the encoder's output channel count and a
        list with the channel width of each of the five blocks.

    Raises:
        ValueError: if the encoder head weight is missing or the checkpoint
            does not contain exactly five blocks.
    """
    # Head output channels come from encode.cnn3 (ConvTranspose2d), whose
    # weight shape is (in_ch, out_ch, kH, kW).
    head_key = "encode.cnn3.weight"
    try:
        head_weight = state_dict[head_key]
    except KeyError:
        # Report a missing head the same way as a wrong block count instead
        # of leaking a bare KeyError.
        raise ValueError(f"Unsupported RIFE model: missing '{head_key}'") from None
    head_ch = head_weight.shape[1]

    # Per-block width is the output channel count of conv0's second conv:
    # conv0 is Sequential(Sequential(Conv2d, ReLU), Sequential(Conv2d, ReLU))
    # and conv0.1.0.weight has shape (c, c // 2, 3, 3).
    block_keys = (f"blocks.{i}.conv0.1.0.weight" for i in range(5))
    channels = [state_dict[key].shape[0] for key in block_keys if key in state_dict]

    if len(channels) != 5:
        raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}")

    return head_ch, channels
|
||||
@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
|
||||
|
||||
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
|
||||
|
||||
# Search folder for the checkpoints listed by the frame interpolation loader node.
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user