mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-16 17:20:51 +08:00
60 lines
1.7 KiB
Python
60 lines
1.7 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from typing import Optional, Dict
|
|
import gc
|
|
|
|
# Number of times CUDAGraphRunner.capture() runs the model eagerly before
# recording the CUDA graph.
_NUM_WARMUP_ITERS = 2
|
|
|
|
class CUDAGraphRunner(nn.Module):
    """Capture one call of ``model`` into a CUDA graph and replay it later.

    Usage is two-phase:

    1. ``capture(*args, **kwargs)`` runs the model a few times eagerly, then
       records one more call into a ``torch.cuda.CUDAGraph``. The tensor
       arguments passed here become the graph's input buffers and the value
       returned by the model becomes its output buffer.
    2. ``forward(*args, **kwargs)`` copies the new tensor values into those
       input buffers, replays the graph, and returns the output buffer.

    Only ``torch.Tensor`` arguments are refreshed on replay; non-tensor
    arguments passed to ``forward`` are not copied anywhere.
    """

    def __init__(self, model):
        """Wrap ``model``; nothing is captured until ``capture`` is called."""
        super().__init__()
        self.model = model

        # Populated by capture():
        #   input_buffers  -> {"args": [tensor, ...], "kwargs": {name: tensor}}
        #   output_buffers -> {"hidden_states": <value returned by the model>}
        self.input_buffers: dict = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """Return the captured graph; fails if ``capture`` has not succeeded."""
        assert self._graph is not None, "capture() must be called before the graph is used"
        return self._graph

    def capture(self, *args, **kwargs):
        """Warm the model up, then record one call of it into a CUDA graph.

        Args:
            *args: Positional arguments forwarded to the model. Tensor
                entries are kept as the positional input buffers.
            **kwargs: Keyword arguments forwarded to the model. Tensor
                entries are kept as the keyword input buffers. The optional
                keys ``"memory_pool"`` and ``"stream"`` are additionally
                passed to ``torch.cuda.graph`` as ``pool`` and ``stream``.

        Raises:
            AssertionError: If a graph has already been captured.
        """
        assert self._graph is None, "capture() may only be called once"

        # NOTE(review): "memory_pool" and "stream" are read from kwargs below
        # but are also forwarded to the model together with everything else --
        # confirm the wrapped model accepts (or ignores) those keys.
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(*args, **kwargs)

        torch.cuda.synchronize()

        # Keep the graph in a local until capture has fully succeeded, so a
        # failure inside the model does not leave the runner with a graph set
        # but no buffers (which would also block any retry of capture()).
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(
            graph,
            pool=kwargs.get("memory_pool", None),
            stream=kwargs.get("stream", None),
        ):
            last_hidden_states = self.model(*args, **kwargs)
            gc.collect()

        torch.cuda.synchronize()

        # The tensors handed to capture() are the buffers the graph reads from
        # on replay, so keep references to exactly those objects.
        self.input_buffers = {
            "args": [arg for arg in args if isinstance(arg, torch.Tensor)],
            "kwargs": {k: v for k, v in kwargs.items() if isinstance(v, torch.Tensor)},
        }
        self.output_buffers = {
            "hidden_states": last_hidden_states
        }
        self._graph = graph

    def forward(self, *args, **kwargs):
        """Copy new inputs into the captured buffers and replay the graph.

        Args:
            *args: Positional arguments in the same layout used for
                ``capture``. Tensor entries are copied, in order, into the
                positional input buffers; non-tensor entries are skipped.
            **kwargs: Keyword arguments. A tensor value is copied only when
                its name has a captured keyword buffer.

        Returns:
            The object stored under ``output_buffers["hidden_states"]``, i.e.
            the value the model returned during capture.

        Raises:
            AssertionError: If ``capture`` has not been called.
            IndexError: If more tensor positional arguments are given than
                were captured.
        """
        graph = self.graph  # fail early with a clear message if not captured

        # input_buffers["args"] holds only the tensor positionals, so pair it
        # with the tensor positionals of this call rather than indexing it by
        # the raw position in ``args`` (which is shifted by non-tensor args).
        arg_buffers = self.input_buffers["args"]
        tensor_args = [arg for arg in args if isinstance(arg, torch.Tensor)]
        if len(tensor_args) > len(arg_buffers):
            raise IndexError(
                f"got {len(tensor_args)} tensor positional arguments but only "
                f"{len(arg_buffers)} were captured"
            )
        for buffer, arg in zip(arg_buffers, tensor_args):
            buffer.copy_(arg, non_blocking=True)

        kwarg_buffers = self.input_buffers["kwargs"]
        for name, value in kwargs.items():
            if name in kwarg_buffers and isinstance(value, torch.Tensor):
                kwarg_buffers[name].copy_(value, non_blocking=True)

        graph.replay()

        return self.output_buffers["hidden_states"]
|