removed cuda graphs + changes to loudness.py

This commit is contained in:
Yousef Rafat 2025-10-24 18:09:30 +03:00
parent 07cfed84ff
commit 203849f6cc
3 changed files with 19 additions and 139 deletions

View File

@ -9,40 +9,10 @@ from enum import Enum
from dataclasses import dataclass, fields
from transformers.cache_utils import StaticCache, DynamicCache, Cache
from typing import Optional, Union, Any
from comfy.model_management import get_free_memory, minimum_inference_memory
import comfy.model_management
NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache }
def estimate_autoregressive_vram(
num_layers: int,
hidden_dim: int,
max_seq_len: int,
batch_size: int = 1,
dtype = torch.float16,
intermediate_factor: float = 4.0,
device = torch.device('cuda')
) -> bool:
dtype_size = torch.finfo(dtype).bits // 8
kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size
# we only calculate hidden states in cuda graphs, so we don't care about the output logits
input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
# rough calculation for activation sizes
intermediate_bytes = intermediate_factor * output_bytes
total_estimated = kv_cache_bytes + input_bytes + output_bytes + intermediate_bytes
# get vram info
free_vram = get_free_memory(device)
minimum_vram = minimum_inference_memory()
enough_vram = free_vram - minimum_vram >= total_estimated
return enough_vram
class TopKLogits:
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@ -282,21 +252,8 @@ class AutoRegressiveGeneration:
} if self.model.cache_implementation == "static" and self.model.use_kv_buckets else None
enough_vram = estimate_autoregressive_vram(
self.model.num_hidden_layers, self.model.hidden_dim, self.model.max_seq_len, dtype = self.dtype, device = device
)
# cuda graphs only help if input shapes are constant
if (
device == "cuda"
and hasattr(model, "capture_model")
and self.model.cache_implementation == "static"
and self.model.use_kv_buckets
and enough_vram
):
self.model.capture_model(self.kv_caches.values())
else:
self.model.generation_config.is_using_cuda_graphs = False
# for now
self.model.generation_config.is_using_cuda_graphs = False
@torch.inference_mode()
def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length = 0,

View File

@ -1,59 +0,0 @@
import torch
import torch.nn as nn
from typing import Optional, Dict
import gc
_NUM_WARMUP_ITERS = 2
class CUDAGraphRunner(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
self._graph: Optional[torch.cuda.CUDAGraph] = None
@property
def graph(self):
assert self._graph is not None
return self._graph
def capture(self, *args, **kwargs):
assert self._graph is None
for _ in range(_NUM_WARMUP_ITERS):
self.model(*args, **kwargs)
torch.cuda.synchronize()
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool = kwargs.get("memory_pool", None), stream = kwargs.get("stream", None)):
last_hidden_states = self.model(*args, **kwargs)
gc.collect()
torch.cuda.synchronize()
self.input_buffers = {
"args": [arg for arg in args if isinstance(arg, torch.Tensor)],
"kwargs": {k: v for k, v in kwargs.items() if isinstance(v, torch.Tensor)},
}
self.output_buffers = {
"hidden_states": last_hidden_states
}
def forward(self, *args, **kwargs):
for i, arg in enumerate(args):
if isinstance(arg, torch.Tensor):
self.input_buffers["args"][i].copy_(arg, non_blocking=True)
for k, v in kwargs.items():
if k in self.input_buffers["kwargs"] and isinstance(v, torch.Tensor):
self.input_buffers["kwargs"][k].copy_(v, non_blocking=True)
self.graph.replay()
return self.output_buffers["hidden_states"]

View File

@ -88,6 +88,7 @@ def fft_conv1d(
class IIRfilter(object):
def __init__(self, G, Q, fc, rate, filter_type, passband_gain=1.0):
G, Q, fc, rate = [t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in (G, Q, fc, rate)]
self.G = G
self.Q = Q
self.fc = fc
@ -98,26 +99,26 @@ class IIRfilter(object):
def generate_coefficients(self):
A = 10**(self.G/40.0)
w0 = 2.0 * np.pi * (self.fc / self.rate)
alpha = np.sin(w0) / (2.0 * self.Q)
w0 = 2.0 * torch.pi * (self.fc / self.rate)
alpha = torch.sin(w0) / (2.0 * self.Q)
if self.filter_type == 'high_shelf':
b0 = A * ( (A+1) + (A-1) * np.cos(w0) + 2 * np.sqrt(A) * alpha )
b1 = -2 * A * ( (A-1) + (A+1) * np.cos(w0) )
b2 = A * ( (A+1) + (A-1) * np.cos(w0) - 2 * np.sqrt(A) * alpha )
a0 = (A+1) - (A-1) * np.cos(w0) + 2 * np.sqrt(A) * alpha
a1 = 2 * ( (A-1) - (A+1) * np.cos(w0) )
a2 = (A+1) - (A-1) * np.cos(w0) - 2 * np.sqrt(A) * alpha
b0 = A * ( (A+1) + (A-1) * torch.cos(w0) + 2 * torch.sqrt(A) * alpha )
b1 = -2 * A * ( (A-1) + (A+1) * torch.cos(w0) )
b2 = A * ( (A+1) + (A-1) * torch.cos(w0) - 2 * torch.sqrt(A) * alpha )
a0 = (A+1) - (A-1) * torch.cos(w0) + 2 * torch.sqrt(A) * alpha
a1 = 2 * ( (A-1) - (A+1) * torch.cos(w0) )
a2 = (A+1) - (A-1) * torch.cos(w0) - 2 * torch.sqrt(A) * alpha
elif self.filter_type == 'high_pass':
b0 = (1 + np.cos(w0))/2
b1 = -(1 + np.cos(w0))
b2 = (1 + np.cos(w0))/2
b0 = (1 + torch.cos(w0))/2
b1 = -(1 + torch.cos(w0))
b2 = (1 + torch.cos(w0))/2
a0 = 1 + alpha
a1 = -2 * np.cos(w0)
a1 = -2 * torch.cos(w0)
a2 = 1 - alpha
return np.array([b0, b1, b2])/a0, np.array([a0, a1, a2])/a0
return torch.tensor([b0, b1, b2])/a0, torch.tensor([a0, a1, a2])/a0
def apply_filter(self, data):
return self.passband_gain * scipy.signal.lfilter(self.b, self.a, data)
@ -160,14 +161,14 @@ class Meter(torch.nn.Module):
for i, (_, filter_stage) in enumerate(self._filters.items()):
b, a = filter_stage.b_and_a
firs[i] = scipy.signal.lfilter(b, a, impulse)
firs[i] = scipy.signal.lfilter(b.numpy(), a.numpy(), impulse)
firs = torch.from_numpy(firs[..., ::-1].copy()).float()
self.register_buffer("firs", firs)
self.register_buffer("passband_gain", passband_gain)
def apply_filter_gpu(self, data: torch.Tensor):
def apply_filter_fir(self, data: torch.Tensor):
# Data is of shape (nb, nch, nt)
# Reshape to (nb*nch, 1, nt)
@ -189,27 +190,8 @@ class Meter(torch.nn.Module):
data = data[:, :nt, :]
return data
def apply_filter_cpu(self, data: torch.Tensor):
for _, filter_stage in self._filters.items():
passband_gain = filter_stage.passband_gain
b, a = filter_stage.b_and_a
a_coeffs = torch.from_numpy(a).float().to(data.device)
b_coeffs = torch.from_numpy(b).float().to(data.device)
_data = data.permute(0, 2, 1)
filtered = torchaudio.functional.lfilter(
_data, a_coeffs, b_coeffs, clamp=False
)
data = passband_gain * filtered.permute(0, 2, 1)
return data
def apply_filter(self, data: torch.Tensor):
if data.is_cuda or self.use_fir:
data = self.apply_filter_gpu(data)
else:
data = self.apply_filter_cpu(data)
return data
return self.apply_filter_fir(data)
def forward(self, data: torch.Tensor):
return self.integrated_loudness(data)