diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py index 552b78b8c..cf4f6e1e1 100644 --- a/comfy_extras/frame_interpolation_models/film_net.py +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -105,6 +105,9 @@ class FeatureExtractor(nn.Module): if j <= i: features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) feature_pyramid.append(features) + # Free sub-pyramids no longer needed by future levels + if i >= self.sub_levels - 1: + sub_pyramids[i - self.sub_levels + 1] = None return feature_pyramid @@ -233,9 +236,11 @@ class FILMNet(nn.Module): fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels] bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels] + # Build warp targets and free full pyramids (only first fpl levels needed from here) fpl = self.fusion_pyramid_levels p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]), concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])] + del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1 results = [] dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype) @@ -247,5 +252,7 @@ class FILMNet(nn.Module): bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp) aligned = [torch.cat([fw, bw, bf, ff], dim=1) for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)] + del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled results.append(self.fuse(aligned)) + del aligned return torch.cat(results, dim=0) diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index 723e9c85a..f0e1cf61f 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -104,15 +104,19 @@ class FrameInterpolate(io.ComfyNode): dtype = model.model_dtype() inference_model = model.model - # BHWC -> BCHW - frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device) - _, C, H, W = frames.shape - - # Pad to model's required alignment (RIFE needs 64, FILM handles any size) + # Free VRAM for inference activations (model weights + ~20x a single frame's worth) + H, W = images.shape[1], images.shape[2] + activation_mem = H * W * 3 * images.element_size() * 20 + model_management.free_memory(activation_mem, device) align = getattr(inference_model, "pad_align", 1) - if align > 1: - from comfy.ldm.common_dit import pad_to_patch_size - frames = pad_to_patch_size(frames, (align, align), padding_mode="reflect") + + # Prepare a single padded frame on device for determining output dimensions + def prepare_frame(idx): + frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device) + if align > 1: + from comfy.ldm.common_dit import pad_to_patch_size + frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect") + return frame if torch_compile: for name, child in inference_model.named_children(): @@ -132,26 +136,29 @@ class FrameInterpolate(io.ComfyNode): batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit) t_values = [t / multiplier for t in range(1, multiplier)] - _, _, pH, pW = frames.shape - # Pre-allocate output tensor, pin for async GPU->CPU transfer + out_dtype = model_management.intermediate_dtype() total_out_frames = total_pairs * multiplier + 1 - result = torch.empty((total_out_frames, C, pH, pW), dtype=dtype, device=offload_device) - use_pin = model_management.pin_memory(result) - result[0] = frames[0] + result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device) + result[0] = images[0].movedim(-1, 0).to(out_dtype) out_idx = 1 - # Pre-compute timestep tensor on device + # Pre-compute timestep tensor on device (padded dimensions needed) + sample = prepare_frame(0) + pH, pW = sample.shape[2], sample.shape[3] ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) ts_full = ts_full.expand(-1, 1, pH, pW) + del sample multi_fn = getattr(inference_model, "forward_multi_timestep", None) feat_cache = {} + prev_frame = None try: for i in range(total_pairs): - img0_single = frames[i:i + 1].to(device) - img1_single = frames[i + 1:i + 2].to(device) + img0_single = prev_frame if prev_frame is not None else prepare_frame(i) + img1_single = prepare_frame(i + 1) + prev_frame = img1_single # Cache features: img1 of pair N becomes img0 of pair N+1 feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single) @@ -160,11 +167,17 @@ class FrameInterpolate(io.ComfyNode): if multi_fn is not None: # Models with timestep-independent flow can compute it once for all timesteps - mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) - result[out_idx:out_idx + num_interp].copy_(mids.to(dtype=dtype), non_blocking=use_pin) - out_idx += num_interp - pbar.update(num_interp) - tqdm_bar.update(num_interp) + try: + mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) + result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype) + out_idx += num_interp + pbar.update(num_interp) + tqdm_bar.update(num_interp) + except model_management.OOM_EXCEPTION: + # Fall back to single-timestep calls + model_management.soft_empty_cache() + multi_fn = None + continue else: j = 0 while j < num_interp: @@ -173,7 +186,7 @@ class FrameInterpolate(io.ComfyNode): img0 = img0_single.expand(b, -1, -1, -1) img1 = img1_single.expand(b, -1, -1, -1) mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache) - result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin) + result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype) out_idx += b pbar.update(b) tqdm_bar.update(b) @@ -184,16 +197,13 @@ class FrameInterpolate(io.ComfyNode): batch = max(1, batch // 2) model_management.soft_empty_cache() - result[out_idx].copy_(frames[i + 1]) + result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype) out_idx += 1 finally: tqdm_bar.close() - if use_pin: - model_management.synchronize() - model_management.unpin_memory(result) - # Crop padding and BCHW -> BHWC - result = result[:, :, :H, :W].movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype()) + # BCHW -> BHWC + result = result.movedim(1, -1).clamp_(0.0, 1.0) return io.NodeOutput(result)