Compare commits

..

No commits in common. "f17251bec65b5760cfedec29eace7d77f4b35130" and "eaf68c9b5bbfbcdac8988741f3948678c9465c1d" have entirely different histories.

3 changed files with 17 additions and 54 deletions

View File

@ -171,10 +171,7 @@ class Flux(nn.Module):
pe = None
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -218,10 +215,7 @@ class Flux(nn.Module):
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0:
@ -1012,7 +1012,7 @@ def force_channels_last():
STREAMS = {}
NUM_STREAMS = 0
NUM_STREAMS = 1
if args.async_offload:
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
@ -1030,7 +1030,7 @@ def current_stream(device):
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS == 0:
if NUM_STREAMS <= 1:
return None
if device in STREAMS:

View File

@ -148,15 +148,6 @@ class LowVramPatch:
else:
return out
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
convert_func = None
@ -278,9 +269,6 @@ class ModelPatcher:
if not hasattr(self.model, 'current_weight_patches_uuid'):
self.model.current_weight_patches_uuid = None
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
def model_size(self):
if self.size > 0:
return self.size
@ -674,16 +662,7 @@ class ModelPatcher:
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
loading.append((module_offload_mem, module_mem, n, m, params))
loading.append((comfy.model_management.module_size(m), n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@ -697,22 +676,20 @@ class ModelPatcher:
load_completely = []
offloaded = []
offload_buffer = 0
loading.sort(reverse=True)
for x in loading:
module_offload_mem, module_mem, n, m, params = x
n = x[1]
m = x[2]
params = x[3]
module_mem = x[0]
lowvram_weight = False
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
if not lowvram_fits:
offload_buffer = potential_offload
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
@ -746,11 +723,9 @@ class ModelPatcher:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
if full_load or lowvram_fits:
if full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
else:
offload_buffer = potential_offload
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
@ -791,7 +766,7 @@ class ModelPatcher:
self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0:
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
@ -803,7 +778,6 @@ class ModelPatcher:
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
self.model.model_offload_buffer_memory = offload_buffer
self.model.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@ -857,7 +831,6 @@ class ModelPatcher:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
@ -876,14 +849,13 @@ class ModelPatcher:
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
offload_buffer = self.model.model_offload_buffer_memory
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
if memory_to_free < memory_freed:
break
module_offload_mem, module_mem, n, m, params = unload
potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
module_mem = unload[0]
n = unload[1]
m = unload[2]
params = unload[3]
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@ -934,18 +906,15 @@ class ModelPatcher:
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
offload_buffer = max(offload_buffer, potential_offload)
logging.debug("freed {}".format(n))
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):