mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Better RAM usage, reduce FILM VRAM peak
This commit is contained in:
parent
257c5312d9
commit
3cbd1d5f71
@ -105,6 +105,9 @@ class FeatureExtractor(nn.Module):
|
||||
if j <= i:
|
||||
features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
|
||||
feature_pyramid.append(features)
|
||||
# Free sub-pyramids no longer needed by future levels
|
||||
if i >= self.sub_levels - 1:
|
||||
sub_pyramids[i - self.sub_levels + 1] = None
|
||||
return feature_pyramid
|
||||
|
||||
|
||||
@ -233,9 +236,11 @@ class FILMNet(nn.Module):
|
||||
fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels]
|
||||
bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels]
|
||||
|
||||
# Build warp targets and free full pyramids (only first fpl levels needed from here)
|
||||
fpl = self.fusion_pyramid_levels
|
||||
p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]),
|
||||
concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])]
|
||||
del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1
|
||||
|
||||
results = []
|
||||
dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype)
|
||||
@ -247,5 +252,7 @@ class FILMNet(nn.Module):
|
||||
bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp)
|
||||
aligned = [torch.cat([fw, bw, bf, ff], dim=1)
|
||||
for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)]
|
||||
del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled
|
||||
results.append(self.fuse(aligned))
|
||||
del aligned
|
||||
return torch.cat(results, dim=0)
|
||||
|
||||
@ -104,15 +104,19 @@ class FrameInterpolate(io.ComfyNode):
|
||||
dtype = model.model_dtype()
|
||||
inference_model = model.model
|
||||
|
||||
# BHWC -> BCHW
|
||||
frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device)
|
||||
_, C, H, W = frames.shape
|
||||
|
||||
# Pad to model's required alignment (RIFE needs 64, FILM handles any size)
|
||||
# Free VRAM for inference activations (model weights + ~20x a single frame's worth)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
activation_mem = H * W * 3 * images.element_size() * 20
|
||||
model_management.free_memory(activation_mem, device)
|
||||
align = getattr(inference_model, "pad_align", 1)
|
||||
if align > 1:
|
||||
from comfy.ldm.common_dit import pad_to_patch_size
|
||||
frames = pad_to_patch_size(frames, (align, align), padding_mode="reflect")
|
||||
|
||||
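# Load one frame onto `device` as BCHW in the model dtype, reflect-padded to the model's alignment when align > 1 (used for both the size probe and the per-pair inference inputs)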
|
||||
def prepare_frame(idx):
|
||||
frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device)
|
||||
if align > 1:
|
||||
from comfy.ldm.common_dit import pad_to_patch_size
|
||||
frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect")
|
||||
return frame
|
||||
|
||||
if torch_compile:
|
||||
for name, child in inference_model.named_children():
|
||||
@ -132,26 +136,29 @@ class FrameInterpolate(io.ComfyNode):
|
||||
|
||||
batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit)
|
||||
t_values = [t / multiplier for t in range(1, multiplier)]
|
||||
_, _, pH, pW = frames.shape
|
||||
|
||||
# Pre-allocate output tensor, pin for async GPU->CPU transfer
|
||||
out_dtype = model_management.intermediate_dtype()
|
||||
total_out_frames = total_pairs * multiplier + 1
|
||||
result = torch.empty((total_out_frames, C, pH, pW), dtype=dtype, device=offload_device)
|
||||
use_pin = model_management.pin_memory(result)
|
||||
result[0] = frames[0]
|
||||
result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device)
|
||||
result[0] = images[0].movedim(-1, 0).to(out_dtype)
|
||||
out_idx = 1
|
||||
|
||||
# Pre-compute timestep tensor on device
|
||||
# Pre-compute timestep tensor on device (padded dimensions needed)
|
||||
sample = prepare_frame(0)
|
||||
pH, pW = sample.shape[2], sample.shape[3]
|
||||
ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1)
|
||||
ts_full = ts_full.expand(-1, 1, pH, pW)
|
||||
del sample
|
||||
|
||||
multi_fn = getattr(inference_model, "forward_multi_timestep", None)
|
||||
feat_cache = {}
|
||||
prev_frame = None
|
||||
|
||||
try:
|
||||
for i in range(total_pairs):
|
||||
img0_single = frames[i:i + 1].to(device)
|
||||
img1_single = frames[i + 1:i + 2].to(device)
|
||||
img0_single = prev_frame if prev_frame is not None else prepare_frame(i)
|
||||
img1_single = prepare_frame(i + 1)
|
||||
prev_frame = img1_single
|
||||
|
||||
# Cache features: img1 of pair N becomes img0 of pair N+1
|
||||
feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single)
|
||||
@ -160,11 +167,17 @@ class FrameInterpolate(io.ComfyNode):
|
||||
|
||||
if multi_fn is not None:
|
||||
# Models with timestep-independent flow can compute it once for all timesteps
|
||||
mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache)
|
||||
result[out_idx:out_idx + num_interp].copy_(mids.to(dtype=dtype), non_blocking=use_pin)
|
||||
out_idx += num_interp
|
||||
pbar.update(num_interp)
|
||||
tqdm_bar.update(num_interp)
|
||||
try:
|
||||
mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache)
|
||||
result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype)
|
||||
out_idx += num_interp
|
||||
pbar.update(num_interp)
|
||||
tqdm_bar.update(num_interp)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
# Fall back to single-timestep calls
|
||||
model_management.soft_empty_cache()
|
||||
multi_fn = None
|
||||
continue
|
||||
else:
|
||||
j = 0
|
||||
while j < num_interp:
|
||||
@ -173,7 +186,7 @@ class FrameInterpolate(io.ComfyNode):
|
||||
img0 = img0_single.expand(b, -1, -1, -1)
|
||||
img1 = img1_single.expand(b, -1, -1, -1)
|
||||
mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache)
|
||||
result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin)
|
||||
result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype)
|
||||
out_idx += b
|
||||
pbar.update(b)
|
||||
tqdm_bar.update(b)
|
||||
@ -184,16 +197,13 @@ class FrameInterpolate(io.ComfyNode):
|
||||
batch = max(1, batch // 2)
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
result[out_idx].copy_(frames[i + 1])
|
||||
result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype)
|
||||
out_idx += 1
|
||||
finally:
|
||||
tqdm_bar.close()
|
||||
if use_pin:
|
||||
model_management.synchronize()
|
||||
model_management.unpin_memory(result)
|
||||
|
||||
# Crop padding and BCHW -> BHWC
|
||||
result = result[:, :, :H, :W].movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype())
|
||||
# BCHW -> BHWC
|
||||
result = result.movedim(1, -1).clamp_(0.0, 1.0)
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user