Better RAM usage, reduce FILM VRAM peak

This commit is contained in:
kijai 2026-04-04 17:28:37 +03:00
parent 257c5312d9
commit 3cbd1d5f71
2 changed files with 45 additions and 28 deletions

View File

@@ -105,6 +105,9 @@ class FeatureExtractor(nn.Module):
if j <= i:
features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
feature_pyramid.append(features)
# Free sub-pyramids no longer needed by future levels
if i >= self.sub_levels - 1:
sub_pyramids[i - self.sub_levels + 1] = None
return feature_pyramid
@@ -233,9 +236,11 @@ class FILMNet(nn.Module):
fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels]
bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels]
# Build warp targets and free full pyramids (only first fpl levels needed from here)
fpl = self.fusion_pyramid_levels
p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]),
concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])]
del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1
results = []
dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype)
@@ -247,5 +252,7 @@ class FILMNet(nn.Module):
bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp)
aligned = [torch.cat([fw, bw, bf, ff], dim=1)
for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)]
del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled
results.append(self.fuse(aligned))
del aligned
return torch.cat(results, dim=0)

View File

@@ -104,15 +104,19 @@ class FrameInterpolate(io.ComfyNode):
dtype = model.model_dtype()
inference_model = model.model
# BHWC -> BCHW
frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device)
_, C, H, W = frames.shape
# Pad to model's required alignment (RIFE needs 64, FILM handles any size)
# Free VRAM for inference activations (model weights + ~20x a single frame's worth)
H, W = images.shape[1], images.shape[2]
activation_mem = H * W * 3 * images.element_size() * 20
model_management.free_memory(activation_mem, device)
align = getattr(inference_model, "pad_align", 1)
if align > 1:
from comfy.ldm.common_dit import pad_to_patch_size
frames = pad_to_patch_size(frames, (align, align), padding_mode="reflect")
# Prepare a single padded frame on device for determining output dimensions
def prepare_frame(idx):
    """Return frame *idx* as a single padded BCHW tensor on the inference device.

    Slices one frame out of the BHWC ``images`` batch, converts the layout
    to BCHW, casts to the model dtype, and moves it to ``device`` — one
    frame at a time, so the whole clip never has to be resident in VRAM
    at once (this is the commit's VRAM-peak reduction).
    If the model declares a spatial alignment requirement (``align > 1``),
    the frame is reflect-padded up to a multiple of ``align``.
    """
    # BHWC -> BCHW for one frame; cast and device-transfer in a single .to().
    frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device)
    if align > 1:
        # Local import mirrors the original top-level usage; only needed
        # when padding is actually required.
        from comfy.ldm.common_dit import pad_to_patch_size
        frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect")
    return frame
if torch_compile:
for name, child in inference_model.named_children():
@@ -132,26 +136,29 @@ class FrameInterpolate(io.ComfyNode):
batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit)
t_values = [t / multiplier for t in range(1, multiplier)]
_, _, pH, pW = frames.shape
# Pre-allocate output tensor, pin for async GPU->CPU transfer
out_dtype = model_management.intermediate_dtype()
total_out_frames = total_pairs * multiplier + 1
result = torch.empty((total_out_frames, C, pH, pW), dtype=dtype, device=offload_device)
use_pin = model_management.pin_memory(result)
result[0] = frames[0]
result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device)
result[0] = images[0].movedim(-1, 0).to(out_dtype)
out_idx = 1
# Pre-compute timestep tensor on device
# Pre-compute timestep tensor on device (padded dimensions needed)
sample = prepare_frame(0)
pH, pW = sample.shape[2], sample.shape[3]
ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1)
ts_full = ts_full.expand(-1, 1, pH, pW)
del sample
multi_fn = getattr(inference_model, "forward_multi_timestep", None)
feat_cache = {}
prev_frame = None
try:
for i in range(total_pairs):
img0_single = frames[i:i + 1].to(device)
img1_single = frames[i + 1:i + 2].to(device)
img0_single = prev_frame if prev_frame is not None else prepare_frame(i)
img1_single = prepare_frame(i + 1)
prev_frame = img1_single
# Cache features: img1 of pair N becomes img0 of pair N+1
feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single)
@@ -160,11 +167,17 @@ class FrameInterpolate(io.ComfyNode):
if multi_fn is not None:
# Models with timestep-independent flow can compute it once for all timesteps
mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache)
result[out_idx:out_idx + num_interp].copy_(mids.to(dtype=dtype), non_blocking=use_pin)
out_idx += num_interp
pbar.update(num_interp)
tqdm_bar.update(num_interp)
try:
mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache)
result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype)
out_idx += num_interp
pbar.update(num_interp)
tqdm_bar.update(num_interp)
except model_management.OOM_EXCEPTION:
# Fall back to single-timestep calls
model_management.soft_empty_cache()
multi_fn = None
continue
else:
j = 0
while j < num_interp:
@@ -173,7 +186,7 @@ class FrameInterpolate(io.ComfyNode):
img0 = img0_single.expand(b, -1, -1, -1)
img1 = img1_single.expand(b, -1, -1, -1)
mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache)
result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin)
result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype)
out_idx += b
pbar.update(b)
tqdm_bar.update(b)
@@ -184,16 +197,13 @@ class FrameInterpolate(io.ComfyNode):
batch = max(1, batch // 2)
model_management.soft_empty_cache()
result[out_idx].copy_(frames[i + 1])
result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype)
out_idx += 1
finally:
tqdm_bar.close()
if use_pin:
model_management.synchronize()
model_management.unpin_memory(result)
# Crop padding and BCHW -> BHWC
result = result[:, :, :H, :W].movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype())
# BCHW -> BHWC
result = result.movedim(1, -1).clamp_(0.0, 1.0)
return io.NodeOutput(result)