mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
8fed4c1d41
@ -131,7 +131,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
|||||||
|
|
||||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||||
|
|
||||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||||
|
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||||
|
|
||||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||||
|
|
||||||
|
|||||||
@ -171,7 +171,10 @@ class Flux(nn.Module):
|
|||||||
pe = None
|
pe = None
|
||||||
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.double_blocks):
|
for i, block in enumerate(self.double_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
@ -215,7 +218,10 @@ class Flux(nn.Module):
|
|||||||
if self.params.global_modulation:
|
if self.params.global_modulation:
|
||||||
vec, _ = self.single_stream_modulation(vec_orig)
|
vec, _ = self.single_stream_modulation(vec_orig)
|
||||||
|
|
||||||
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
|
transformer_options["block_type"] = "single"
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
|
|||||||
@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
loaded_memory = loaded_model.model_loaded_memory()
|
loaded_memory = loaded_model.model_loaded_memory()
|
||||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||||
|
|
||||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||||
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
||||||
|
|
||||||
if lowvram_model_memory == 0:
|
if lowvram_model_memory == 0:
|
||||||
@ -1012,9 +1012,18 @@ def force_channels_last():
|
|||||||
|
|
||||||
|
|
||||||
STREAMS = {}
|
STREAMS = {}
|
||||||
NUM_STREAMS = 1
|
NUM_STREAMS = 0
|
||||||
if args.async_offload:
|
if args.async_offload is not None:
|
||||||
NUM_STREAMS = 2
|
NUM_STREAMS = args.async_offload
|
||||||
|
else:
|
||||||
|
# Enable by default on Nvidia
|
||||||
|
if is_nvidia():
|
||||||
|
NUM_STREAMS = 2
|
||||||
|
|
||||||
|
if args.disable_async_offload:
|
||||||
|
NUM_STREAMS = 0
|
||||||
|
|
||||||
|
if NUM_STREAMS > 0:
|
||||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||||
|
|
||||||
def current_stream(device):
|
def current_stream(device):
|
||||||
@ -1030,7 +1039,7 @@ def current_stream(device):
|
|||||||
stream_counters = {}
|
stream_counters = {}
|
||||||
def get_offload_stream(device):
|
def get_offload_stream(device):
|
||||||
stream_counter = stream_counters.get(device, 0)
|
stream_counter = stream_counters.get(device, 0)
|
||||||
if NUM_STREAMS <= 1:
|
if NUM_STREAMS == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if device in STREAMS:
|
if device in STREAMS:
|
||||||
|
|||||||
@ -148,6 +148,15 @@ class LowVramPatch:
|
|||||||
else:
|
else:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
|
||||||
|
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
|
||||||
|
|
||||||
|
def low_vram_patch_estimate_vram(model, key):
|
||||||
|
weight, set_func, convert_func = get_key_weight(model, key)
|
||||||
|
if weight is None:
|
||||||
|
return 0
|
||||||
|
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
|
||||||
|
|
||||||
def get_key_weight(model, key):
|
def get_key_weight(model, key):
|
||||||
set_func = None
|
set_func = None
|
||||||
convert_func = None
|
convert_func = None
|
||||||
@ -269,6 +278,9 @@ class ModelPatcher:
|
|||||||
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
||||||
self.model.current_weight_patches_uuid = None
|
self.model.current_weight_patches_uuid = None
|
||||||
|
|
||||||
|
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
||||||
|
self.model.model_offload_buffer_memory = 0
|
||||||
|
|
||||||
def model_size(self):
|
def model_size(self):
|
||||||
if self.size > 0:
|
if self.size > 0:
|
||||||
return self.size
|
return self.size
|
||||||
@ -662,7 +674,16 @@ class ModelPatcher:
|
|||||||
skip = True # skip random weights in non leaf modules
|
skip = True # skip random weights in non leaf modules
|
||||||
break
|
break
|
||||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
module_mem = comfy.model_management.module_size(m)
|
||||||
|
module_offload_mem = module_mem
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
weight_key = "{}.weight".format(n)
|
||||||
|
bias_key = "{}.bias".format(n)
|
||||||
|
if weight_key in self.patches:
|
||||||
|
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
|
||||||
|
if bias_key in self.patches:
|
||||||
|
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
|
||||||
|
loading.append((module_offload_mem, module_mem, n, m, params))
|
||||||
return loading
|
return loading
|
||||||
|
|
||||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
@ -676,20 +697,22 @@ class ModelPatcher:
|
|||||||
|
|
||||||
load_completely = []
|
load_completely = []
|
||||||
offloaded = []
|
offloaded = []
|
||||||
|
offload_buffer = 0
|
||||||
loading.sort(reverse=True)
|
loading.sort(reverse=True)
|
||||||
for x in loading:
|
for x in loading:
|
||||||
n = x[1]
|
module_offload_mem, module_mem, n, m, params = x
|
||||||
m = x[2]
|
|
||||||
params = x[3]
|
|
||||||
module_mem = x[0]
|
|
||||||
|
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
|
|
||||||
|
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
|
||||||
|
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
|
||||||
|
|
||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
|
|
||||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if not lowvram_fits:
|
||||||
|
offload_buffer = potential_offload
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
lowvram_counter += 1
|
lowvram_counter += 1
|
||||||
lowvram_mem_counter += module_mem
|
lowvram_mem_counter += module_mem
|
||||||
@ -723,9 +746,11 @@ class ModelPatcher:
|
|||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
|
|
||||||
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
if full_load or lowvram_fits:
|
||||||
mem_counter += module_mem
|
mem_counter += module_mem
|
||||||
load_completely.append((module_mem, n, m, params))
|
load_completely.append((module_mem, n, m, params))
|
||||||
|
else:
|
||||||
|
offload_buffer = potential_offload
|
||||||
|
|
||||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
@ -766,7 +791,7 @@ class ModelPatcher:
|
|||||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||||
|
|
||||||
if lowvram_counter > 0:
|
if lowvram_counter > 0:
|
||||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
|
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
else:
|
else:
|
||||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||||
@ -778,6 +803,7 @@ class ModelPatcher:
|
|||||||
self.model.lowvram_patch_counter += patch_counter
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
self.model.model_loaded_weight_memory = mem_counter
|
self.model.model_loaded_weight_memory = mem_counter
|
||||||
|
self.model.model_offload_buffer_memory = offload_buffer
|
||||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||||
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||||
@ -831,6 +857,7 @@ class ModelPatcher:
|
|||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
self.model.model_offload_buffer_memory = 0
|
||||||
|
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
if hasattr(m, "comfy_patched_weights"):
|
if hasattr(m, "comfy_patched_weights"):
|
||||||
@ -849,13 +876,14 @@ class ModelPatcher:
|
|||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
unload_list = self._load_list()
|
unload_list = self._load_list()
|
||||||
unload_list.sort()
|
unload_list.sort()
|
||||||
|
offload_buffer = self.model.model_offload_buffer_memory
|
||||||
|
|
||||||
for unload in unload_list:
|
for unload in unload_list:
|
||||||
if memory_to_free < memory_freed:
|
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
|
||||||
break
|
break
|
||||||
module_mem = unload[0]
|
module_offload_mem, module_mem, n, m, params = unload
|
||||||
n = unload[1]
|
|
||||||
m = unload[2]
|
potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
|
||||||
params = unload[3]
|
|
||||||
|
|
||||||
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
||||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||||
@ -906,15 +934,18 @@ class ModelPatcher:
|
|||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
m.comfy_patched_weights = False
|
m.comfy_patched_weights = False
|
||||||
memory_freed += module_mem
|
memory_freed += module_mem
|
||||||
|
offload_buffer = max(offload_buffer, potential_offload)
|
||||||
logging.debug("freed {}".format(n))
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
for param in params:
|
for param in params:
|
||||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||||
|
|
||||||
|
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
self.model.lowvram_patch_counter += patch_counter
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.model_loaded_weight_memory -= memory_freed
|
self.model.model_loaded_weight_memory -= memory_freed
|
||||||
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
|
self.model.model_offload_buffer_memory = offload_buffer
|
||||||
|
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||||
return memory_freed
|
return memory_freed
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||||
|
|||||||
@ -425,7 +425,8 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||||
return plain_tensor * scale
|
plain_tensor.mul_(scale)
|
||||||
|
return plain_tensor
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_plain_tensors(cls, qtensor):
|
def get_plain_tensors(cls, qtensor):
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.32.9
|
comfyui-frontend-package==1.32.9
|
||||||
comfyui-workflow-templates==0.7.20
|
comfyui-workflow-templates==0.7.23
|
||||||
comfyui-embedded-docs==0.3.1
|
comfyui-embedded-docs==0.3.1
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user