From 9b52e24430da496b48cd325b58e8177a2174b277 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 20 May 2026 17:46:21 +0300 Subject: [PATCH] added dynamic chunking --- comfy/ldm/trellis2/flexgemm.py | 54 ++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/comfy/ldm/trellis2/flexgemm.py b/comfy/ldm/trellis2/flexgemm.py index eb08d2970..416e322ab 100644 --- a/comfy/ldm/trellis2/flexgemm.py +++ b/comfy/ldm/trellis2/flexgemm.py @@ -107,6 +107,59 @@ def build_submanifold_neighbor_map( return neighbor +def get_recommended_chunk_mem( + device=None, + safety_fraction: float = 0.4, + min_gb: float = 0.25, + max_gb: float = 8.0, +): + + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + device = torch.device(device) + + if device.type == 'cuda': + try: + idx = device.index if device.index is not None else 0 + free_bytes, total_bytes = torch.cuda.mem_get_info(idx) + free_gb = free_bytes / (1024 ** 3) + total_gb = total_bytes / (1024 ** 3) + + recommended = free_gb * safety_fraction + result = max(min_gb, min(recommended, max_gb)) + return result + + except Exception: + try: + idx = device.index if device.index is not None else 0 + total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024 ** 3) + except Exception: + total_gb = 16.0 + + if total_gb < 12: + result = 0.5 + elif total_gb < 16: + result = 0.75 + elif total_gb < 24: + result = 1.0 + elif total_gb < 32: + result = 2.0 + elif total_gb < 48: + result = 4.0 + else: + result = 6.0 + return result + + else: + try: + import psutil + avail_gb = psutil.virtual_memory().available / (1024 ** 3) + recommended = avail_gb * safety_fraction + result = max(min_gb, min(recommended, max_gb)) + return result + except ImportError: + return min_gb def sparse_submanifold_conv3d( feats: torch.Tensor, @@ -133,6 +186,7 @@ def sparse_submanifold_conv3d( V = Kw * Kh * Kd device = feats.device sentinel = -1 + max_chunk_mem_gb = get_recommended_chunk_mem(device) if neighbor_cache is None: b_stride = W * H * D