mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 16:50:57 +08:00
Merge 0eff43261b into be518db5a7
This commit is contained in:
commit 8c1914a7eb
@@ -25,11 +25,11 @@ class AudioEncoderModel():
         elif model_type == "whisper3":
             self.model = WhisperLargeV3(**model_config)
         self.model.eval()
-        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
         self.model_sample_rate = 16000

     def load_sd(self, sd):
-        return self.model.load_state_dict(sd, strict=False)
+        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

     def get_sd(self):
         return self.model.state_dict()
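For context on the `assign=` flag threaded through these `load_sd` changes: `torch.nn.Module.load_state_dict(..., assign=True)` (PyTorch >= 2.1) replaces the module's parameters with the state-dict tensors themselves instead of copying values into preallocated storage, which is what lets a dynamic patcher avoid materializing a second copy of the weights. A minimal illustration in plain PyTorch (not ComfyUI code):

import torch

lin = torch.nn.Linear(2, 2)
sd = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}

# default copy semantics: existing parameter storage is kept, values copied in
lin.load_state_dict(sd)
assert lin.weight.data_ptr() != sd["weight"].data_ptr()

# assign semantics: the state-dict tensors become the parameters
lin.load_state_dict(sd, assign=True)
assert lin.weight.data_ptr() == sd["weight"].data_ptr()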
@@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
+    DynamicVRAM = "dynamic_vram"

 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
@@ -257,3 +258,6 @@ elif args.fast == []:
 # '--fast' is provided with a list of performance features, use that list
 else:
     args.fast = set(args.fast)
+
+def enables_dynamic_vram():
+    return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
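Illustrative usage (not part of the diff): the feature is opt-in via the `--fast` flag documented above, and `enables_dynamic_vram()` additionally requires that neither `--highvram` nor `--gpu-only` is set, since those flags pin whole models to the GPU:

# python main.py --fast dynamic_vram
# python main.py --fast                          # bare --fast enables every PerformanceFeature
# python main.py --fast dynamic_vram --highvram  # dynamic VRAM stays disabled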
@@ -47,10 +47,10 @@ class ClipVisionModel():
         self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
         self.model.eval()

-        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

     def load_sd(self, sd):
-        return self.model.load_state_dict(sd, strict=False)
+        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

     def get_sd(self):
         return self.model.state_dict()
@@ -203,7 +203,7 @@ class ControlNet(ControlBase):
         self.control_model = control_model
         self.load_device = load_device
         if control_model is not None:
-            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+            self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

         self.compression_ratio = compression_ratio
         self.global_average_pooling = global_average_pooling
@@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
         self.model_class = UPSAMPLERS.get(model_type)
         self.model = self.model_class(**config).eval()

-        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

     def load_sd(self, sd):
-        return self.model.load_state_dict(sd, strict=True)
+        return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())

     def get_sd(self):
         return self.model.state_dict()
54
comfy/memory_management.py
Normal file
@@ -0,0 +1,54 @@
+from comfy.quant_ops import QuantizedTensor
+
+import comfy_aimdo.torch
+
+def vram_aligned_size(tensor):
+    if isinstance(tensor, list):
+        return sum([vram_aligned_size(t) for t in tensor])
+
+    if isinstance(tensor, QuantizedTensor):
+        inner_tensors, _ = tensor.__tensor_flatten__()
+        return vram_aligned_size([getattr(tensor, attr) for attr in inner_tensors])
+
+    if tensor is None:
+        return 0
+
+    size = tensor.numel() * tensor.element_size()
+    alignment_req = 1024
+    return (size + alignment_req - 1) // alignment_req * alignment_req
+
+def interpret_gathered_like(tensors, gathered):
+    offset = 0
+    dest_views = []
+
+    if gathered.dim() != 1 or gathered.element_size() != 1:
+        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")
+
+    for tensor in tensors:
+        if tensor is None:
+            dest_views.append(None)
+            continue
+
+        if isinstance(tensor, QuantizedTensor):
+            inner_tensors, qt_ctx = tensor.__tensor_flatten__()
+            templates = {attr: getattr(tensor, attr) for attr in inner_tensors}
+        else:
+            templates = {"data": tensor}
+
+        actuals = {}
+        for attr, template in templates.items():
+            size = template.numel() * template.element_size()
+            if offset + size > gathered.numel():
+                raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}.")
+            actuals[attr] = gathered[offset:offset + size].view(dtype=template.dtype).view(template.shape)
+            offset += vram_aligned_size(template)
+
+        if isinstance(tensor, QuantizedTensor):
+            dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
+        else:
+            dest_views.append(actuals["data"])
+
+    return dest_views
+
+aimdo_allocator = comfy_aimdo.torch.CUDAPluggableAllocator()
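Example (illustrative, not part of the diff): packing two plain tensors into one flat byte buffer and carving typed views back out with the helpers above. Each tensor is padded to the 1KB alignment `vram_aligned_size` enforces:

import torch

w = torch.randn(4, 4, dtype=torch.float16)   # 32 bytes, padded to 1024
b = torch.randn(4, dtype=torch.float16)      # 8 bytes, padded to 1024

# one flat single-byte buffer sized for both
buf = torch.empty(vram_aligned_size([w, b]), dtype=torch.uint8)

# carve typed, shaped views out of the flat buffer and fill them
w_view, b_view = interpret_gathered_like([w, b], buf)
w_view.copy_(w)
b_view.copy_(b)

assert w_view.shape == w.shape and w_view.dtype == w.dtype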
@@ -298,7 +298,7 @@ class BaseModel(torch.nn.Module):

         return out

-    def load_model_weights(self, sd, unet_prefix=""):
+    def load_model_weights(self, sd, unet_prefix="", assign=False):
         to_load = {}
         keys = list(sd.keys())
         for k in keys:
@@ -306,7 +306,7 @@ class BaseModel(torch.nn.Module):
                 to_load[k[len(unet_prefix):]] = sd.pop(k)

         to_load = self.model_config.process_unet_state_dict(to_load)
-        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+        m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
         if len(m) > 0:
             logging.warning("unet missing: {}".format(m))
@@ -321,7 +321,7 @@ class BaseModel(torch.nn.Module):
     def process_latent_out(self, latent):
         return self.latent_format.process_out(latent)

-    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+    def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
         extra_sds = []
         if clip_state_dict is not None:
             extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
@@ -329,10 +329,7 @@ class BaseModel(torch.nn.Module):
             extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
         if clip_vision_state_dict is not None:
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

-        unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

         if self.model_type == ModelType.V_PREDICTION:
             unet_state_dict["v_pred"] = torch.tensor([])
@@ -775,8 +772,8 @@ class StableAudio1(BaseModel):
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

-    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
-        sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
+    def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+        sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
         d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
         for k in d:
             s = d[k]
@@ -26,6 +26,12 @@ import platform
 import weakref
 import gc
 import os
+from contextlib import nullcontext
+import comfy.utils
+import comfy.quant_ops
+
+import comfy_aimdo.torch
+import comfy_aimdo.model_vbar

 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
@@ -592,7 +598,7 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

-def free_memory(memory_required, device, keep_loaded=[]):
+def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
@@ -607,15 +613,22 @@ def free_memory(memory_required, device, keep_loaded=[]):

     for x in sorted(can_unload):
         i = x[-1]
-        memory_to_free = None
+        memory_to_free = 1e32
+        ram_to_free = 1e32
         if not DISABLE_SMART_MEMORY:
-            free_mem = get_free_memory(device)
-            if free_mem > memory_required:
-                break
-            memory_to_free = memory_required - free_mem
+            memory_to_free = memory_required - get_free_memory(device)
+            ram_to_free = ram_required - psutil.virtual_memory().available
+
+        if current_loaded_models[i].model.is_dynamic() and for_dynamic:
+            #don't actually unload dynamic models for the sake of other dynamic models
+            #as that works on-demand.
+            memory_required -= current_loaded_models[i].model.loaded_size()
+            continue
         logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
+        if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
             unloaded_model.append(i)
+        if ram_to_free > 0:
+            current_loaded_models[i].model.partially_unload_ram(ram_to_free)

     for i in sorted(unloaded_model, reverse=True):
         unloaded_models.append(current_loaded_models.pop(i))
@@ -650,7 +663,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):

     models_to_load = []

+    free_for_dynamic = True
     for x in models:
+        if not x.is_dynamic():
+            free_for_dynamic = False
         loaded_model = LoadedModel(x)
         try:
             loaded_model_index = current_loaded_models.index(loaded_model)
@@ -676,19 +692,25 @@
             model_to_unload.model.detach(unpatch_all=False)
             model_to_unload.model_finalizer.detach()

     total_memory_required = {}
+    total_ram_required = {}
     for loaded_model in models_to_load:
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
+        #want to do.
+        #FIXME: This should subtract off the to_load current pin consumption.
+        total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2

     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
+            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])

     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_mem = get_free_memory(device)
             if free_mem < minimum_memory_required:
-                models_l = free_memory(minimum_memory_required, device)
+                models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic)
                 logging.info("{} models unloaded.".format(len(models_l)))

     for loaded_model in models_to_load:
@@ -732,6 +754,9 @@ def loaded_models(only_currently_used=False):

 def cleanup_models_gc():
     do_gc = False
+
+    reset_cast_buffers()
+
     for i in range(len(current_loaded_models)):
         cur = current_loaded_models[i]
         if cur.is_dead():
@@ -1051,6 +1076,49 @@ def current_stream(device):
     return None

 stream_counters = {}
+
+STREAM_CAST_BUFFERS = {}
+LARGEST_CASTED_WEIGHT = (None, 0)
+
+def get_cast_buffer(offload_stream, device, size, ref):
+    global LARGEST_CASTED_WEIGHT
+
+    if offload_stream is not None:
+        wf_context = offload_stream
+        if hasattr(wf_context, "as_context"):
+            wf_context = wf_context.as_context(offload_stream)
+    else:
+        wf_context = nullcontext()
+
+    cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
+    if cast_buffer is None or cast_buffer.numel() < size:
+        if ref is LARGEST_CASTED_WEIGHT[0]:
+            #If there is one giant weight we do not want both streams to
+            #allocate a buffer for it. It's up to the caster to get the other
+            #offload stream in this corner case
+            return None
+        if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
+            #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
+            torch.cuda.synchronize()
+            del STREAM_CAST_BUFFERS[offload_stream]
+            del cast_buffer
+            torch.cuda.empty_cache()
+        with wf_context:
+            cast_buffer = torch.empty((size,), dtype=torch.int8, device=device)
+        STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
+
+    if size > LARGEST_CASTED_WEIGHT[1]:
+        LARGEST_CASTED_WEIGHT = (ref, size)
+
+    return cast_buffer
+
+def reset_cast_buffers():
+    global LARGEST_CASTED_WEIGHT
+    LARGEST_CASTED_WEIGHT = (None, 0)
+    torch.cuda.synchronize()
+    STREAM_CAST_BUFFERS.clear()
+    torch.cuda.empty_cache()

 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
     if NUM_STREAMS == 0:
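Example (illustrative, requires a CUDA device; passing `offload_stream=None` is a simplification — real callers pass one of the offload streams): the intended call pattern for the cast-buffer cache, which is reset from `cleanup_models_gc()` via `reset_cast_buffers()`:

import torch

device = torch.device("cuda:0")
ref = object()  # stand-in for the module whose weight is being cast

# ask the cache for a scratch buffer; None means "retry on the other offload stream"
buf = get_cast_buffer(None, device, 16 * 1024 * 1024, ref)
if buf is not None:
    assert buf.numel() >= 16 * 1024 * 1024  # the buffer is reused and only ever grows

# during model cleanup, hand all cached buffers back to the allocator
reset_cast_buffers()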
@@ -1093,7 +1161,59 @@ def sync_stream(device, stream):
         return
     current_stream(device).wait_stream(stream)

-def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
+def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
+    wf_context = nullcontext()
+    if stream is not None:
+        wf_context = stream
+        if hasattr(wf_context, "as_context"):
+            wf_context = wf_context.as_context(stream)
+
+    dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
+    with wf_context:
+        for tensor in tensors:
+            dest_view = dest_views.pop(0)
+            if tensor is None:
+                continue
+            dest_view.copy_(tensor, non_blocking=non_blocking)
+
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
+    if hasattr(weight, "_v"):
+        #Unexpected usage patterns. There is no reason these don't work but they
+        #have no testing and no callers do this.
+        assert r is None
+        assert stream is None
+
+        r = torch.empty_like(weight, dtype=dtype, device=device)
+
+        signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
+        if signature is not None:
+            raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
+            v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0]
+
+            if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+                #always take a deep copy even if _v is good, as we have no reasonable point to unpin
+                #a non comfy weight
+                r.copy_(v_tensor)
+                comfy_aimdo.model_vbar.vbar_unpin(weight._v)
+                return r
+
+        r.copy_(weight, non_blocking=non_blocking)
+
+        #FIXME: remove hooks before PR
+        if hasattr(weight, "comfy_hook"):
+            dtype = r.dtype
+            r = weight.comfy_hook(r)
+            if r.dtype != dtype:
+                r = comfy.float.stochastic_rounding(r, dtype, seed=comfy.utils.string_to_seed(weight.seed_key))
+
+        if signature is not None:
+            v_tensor.copy_(r)
+            comfy_aimdo.model_vbar.vbar_unpin(weight._v)
+
+        return r
+
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
@@ -1112,10 +1232,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
         if hasattr(wf_context, "as_context"):
             wf_context = wf_context.as_context(stream)
         with wf_context:
-            r = torch.empty_like(weight, dtype=dtype, device=device)
+            if r is None:
+                r = torch.empty_like(weight, dtype=dtype, device=device)
             r.copy_(weight, non_blocking=non_blocking)
     else:
-        r = torch.empty_like(weight, dtype=dtype, device=device)
+        if r is None:
+            r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight, non_blocking=non_blocking)
     return r
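Example (illustrative): the new `r=` parameter lets a caller hand `cast_to` a preallocated destination — typically a view carved out of a cast buffer — so the copy lands in reused memory instead of a fresh allocation:

import torch

src = torch.randn(1024, dtype=torch.float16)   # host-side weight
view = torch.empty_like(src, device="cuda")    # stand-in for a cast-buffer view

out = cast_to(src, device=torch.device("cuda"), non_blocking=True, r=view)
assert out is view  # filled in place, no new allocation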
@@ -1135,7 +1257,7 @@ if not args.disable_pinned_memory:
         MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
         logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))

-PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
+PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])

 def discard_cuda_async_error():
     try:
@@ -1557,6 +1679,7 @@ def soft_empty_cache(force=False):
     elif is_mlu():
         torch.mlu.empty_cache()
     elif torch.cuda.is_available():
+        torch.cuda.synchronize()
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
@@ -38,19 +38,7 @@ from comfy.comfy_types import UnetWrapperFunction
 from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP

-
-def string_to_seed(data):
-    crc = 0xFFFFFFFF
-    for byte in data:
-        if isinstance(byte, str):
-            byte = ord(byte)
-        crc ^= byte
-        for _ in range(8):
-            if crc & 1:
-                crc = (crc >> 1) ^ 0xEDB88320
-            else:
-                crc >>= 1
-    return crc ^ 0xFFFFFFFF
+import comfy_aimdo.model_vbar

 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
     to = model_options["transformer_options"].copy()
@@ -212,6 +200,27 @@ class MemoryCounter:
     def decrement(self, used: int):
         self.value -= used

+CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0)
+
+class LazyCastingParam(torch.nn.Parameter):
+    def __new__(cls, model, key, tensor):
+        return super().__new__(cls, tensor)
+
+    def __init__(self, model, key, tensor):
+        self.model = model
+        self.key = key
+
+    @property
+    def device(self):
+        return CustomTorchDevice
+
+    #safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is
+    #then just a short lived thing in the safetensors serialization logic inside its big for loop over
+    #all weights getting garbage collected per-weight
+    def to(self, *args, **kwargs):
+        return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
+
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
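To make the LazyCastingParam flow concrete, a hedged sketch of what happens at save time: a serializer such as safetensors walks the state dict and calls `.to("cpu")` on each entry, so for these placeholders only one fully patched weight is materialized at a time (`write_tensor` below is a hypothetical stand-in for the serializer's writer):

def write_tensor(key, tensor):  # hypothetical stand-in for a safetensors writer
    ...

for key, param in unet_state_dict.items():  # unet_state_dict as built by state_dict_for_saving below
    tensor = param.to("cpu")   # LazyCastingParam.to() patches and casts this single key
    write_tensor(key, tensor)  # the casted tensor is garbage-collected on the next iteration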
@@ -269,6 +278,9 @@ class ModelPatcher:
         if not hasattr(self.model, 'model_offload_buffer_memory'):
             self.model.model_offload_buffer_memory = 0

+    def is_dynamic(self):
+        return False
+
     def model_size(self):
         if self.size > 0:
             return self.size
@@ -284,6 +296,9 @@ class ModelPatcher:
     def lowvram_patch_counter(self):
         return self.model.lowvram_patch_counter

+    def get_free_memory(self, device):
+        return comfy.model_management.get_free_memory(device)
+
     def clone(self):
         n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
@@ -611,14 +626,14 @@ class ModelPatcher:
             sd.pop(k)
         return sd

-    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
-        if key not in self.patches:
-            return
-
+    def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
         weight, set_func, convert_func = get_key_weight(self.model, key)
+        if key not in self.patches:
+            return weight
+
         inplace_update = self.weight_inplace_update or inplace_update

-        if key not in self.backup:
+        if key not in self.backup and not return_weight:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

         temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
@@ -631,13 +646,15 @@ class ModelPatcher:

         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
         if set_func is None:
-            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
-            if inplace_update:
+            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
+            if return_weight:
+                return out_weight
+            elif inplace_update:
                 comfy.utils.copy_to_param(self.model, key, out_weight)
             else:
                 comfy.utils.set_attr_param(self.model, key, out_weight)
         else:
-            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
+            return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight)

     def pin_weight_to_device(self, key):
         weight, set_func, convert_func = get_key_weight(self.model, key)
@@ -654,7 +671,7 @@ class ModelPatcher:
         for key in list(self.pinned):
            self.unpin_weight(key)

-    def _load_list(self):
+    def _load_list(self, prio_comfy_cast_weights=False):
         loading = []
         for n, m in self.model.named_modules():
             params = []
@@ -681,7 +698,8 @@ class ModelPatcher:
                     return 0
             module_offload_mem += check_module_offload_mem("{}.weight".format(n))
             module_offload_mem += check_module_offload_mem("{}.bias".format(n))
-            loading.append((module_offload_mem, module_mem, n, m, params))
+            prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
+            loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
         return loading

     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -984,6 +1002,9 @@ class ModelPatcher:

         return self.model.model_loaded_weight_memory - current_used

+    def partially_unload_ram(self, ram_to_unload):
+        pass
+
     def detach(self, unpatch_all=True):
         self.eject_model()
         self.model_patches_to(self.offload_device)
@@ -1317,10 +1338,10 @@ class ModelPatcher:
                                                   key, original_weights=original_weights)
             del original_weights[key]
             if set_func is None:
-                out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+                out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
                 comfy.utils.copy_to_param(self.model, key, out_weight)
             else:
-                set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
+                set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key))
             if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                 # TODO: disable caching if not enough system RAM to do so
                 target_device = self.offload_device
@@ -1355,7 +1376,237 @@ class ModelPatcher:
         self.unpatch_hooks()
         self.clear_cached_hook_weights()

+    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+        unet_state_dict = self.model.diffusion_model.state_dict()
+        for k, v in unet_state_dict.items():
+            op_keys = k.rsplit('.', 1)
+            if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
+                continue
+            try:
+                op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
+            except:
+                continue
+            if not op or not hasattr(op, "comfy_cast_weights") or \
+                (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
+                continue
+            key = "diffusion_model." + k
+            unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
+        return self.model.state_dict_for_saving(unet_state_dict)
+
     def __del__(self):
         self.unpin_all_weights()
         self.detach(unpatch_all=False)

+class ModelPatcherDynamic(ModelPatcher):
+
+    def __new__(cls, model, load_device, offload_device, size=0, weight_inplace_update=False):
+        if comfy.model_management.is_device_cpu(load_device):
+            #reroute to default MP for CPUs
+            return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update)
+        return super().__new__(cls)
+
+    def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
+        super().__init__(model, load_device, offload_device, size, weight_inplace_update)
+        #this is now way more dynamic and we don't support the same base model for both Dynamic
+        #and non-dynamic patchers.
+        if hasattr(self.model, "model_loaded_weight_memory"):
+            del self.model.model_loaded_weight_memory
+        if not hasattr(self.model, "dynamic_vbars"):
+            self.model.dynamic_vbars = {}
+        assert load_device is not None
+
+    def is_dynamic(self):
+        return True
+
+    def _vbar_get(self, create=False):
+        if self.load_device == torch.device("cpu"):
+            return None
+        vbar = self.model.dynamic_vbars.get(self.load_device, None)
+        if create and vbar is None:
+            vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 1.2, self.load_device.index)
+            self.model.dynamic_vbars[self.load_device] = vbar
+        return vbar
+
+    def loaded_size(self):
+        vbar = self._vbar_get()
+        if vbar is None:
+            return 0
+        return vbar.loaded_size()
+
+    def get_free_memory(self, device):
+        #NOTE: on high condition / batch counts, estimate should have already vacated
+        #all non-dynamic models so this is safe even if it's not 100% true that this
+        #would all be available for inference use.
+        return comfy.model_management.get_total_memory(device) - self.model_size()
+
+    #Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
+
+    def pin_weight_to_device(self, key):
+        raise RuntimeError("pin_weight_to_device invalid for dynamic weight loading")
+
+    def unpin_weight(self, key):
+        raise RuntimeError("unpin_weight invalid for dynamic weight loading")
+
+    def unpin_all_weights(self):
+        pass
+
+    def memory_required(self, input_shape):
+        #Pad this significantly. We are trying to get away from precise estimates. This
+        #estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you
+        #use all ModelPatcherDynamic this is ignored and it's all done dynamically.
+        return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
+
+    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
+
+        #Force patching doesn't make sense in Dynamic loading, as you don't know what does and
+        #doesn't need to be forced at this stage. The only thing you could do would be patch
+        #it all on CPU which consumes huge RAM.
+        assert not force_patch_weights
+
+        #Full load doesn't make sense as we don't actually have any loader capability here and
+        #now.
+        assert not full_load
+
+        assert device_to == self.load_device
+
+        num_patches = 0
+        allocated_size = 0
+
+        with self.use_ejected():
+            self.unpatch_hooks()
+
+            vbar = self._vbar_get(create=True)
+            if vbar is not None:
+                vbar.prioritize()
+
+            #We have way more tools for acceleration on comfy weight offloading, so always
+            #prioritize the non-comfy weights (note the order reverse).
+            loading = self._load_list(prio_comfy_cast_weights=True)
+            loading.sort(reverse=True)
+
+            for x in loading:
+                _, _, _, n, m, params = x
+
+                def set_dirty(item, dirty):
+                    if dirty or not hasattr(item, "_v_signature"):
+                        item._v_signature = None
+                    if dirty:
+                        comfy.pinned_memory.unpin_memory(item)
+
+                def setup_param(self, m, n, param_key):
+                    nonlocal num_patches
+                    key = "{}.{}".format(n, param_key)
+
+                    weight_function = []
+
+                    weight, _, _ = get_key_weight(self.model, key)
+                    if key in self.patches:
+                        setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
+                        num_patches += 1
+                    else:
+                        setattr(m, param_key + "_lowvram_function", None)
+
+                    if key in self.weight_wrapper_patches:
+                        weight_function.extend(self.weight_wrapper_patches[key])
+                    setattr(m, param_key + "_function", weight_function)
+                    return comfy.memory_management.vram_aligned_size(weight)
+
+                if hasattr(m, "comfy_cast_weights"):
+                    m.comfy_cast_weights = True
+                    m.pin_failed = False
+                    m.seed_key = n
+                    set_dirty(m, dirty)
+
+                    v_weight_size = 0
+                    v_weight_size += setup_param(self, m, n, "weight")
+                    v_weight_size += setup_param(self, m, n, "bias")
+
+                    if vbar is not None and not hasattr(m, "_v"):
+                        m._v = vbar.alloc(v_weight_size)
+                        allocated_size += v_weight_size
+
+                else:
+                    for param in params:
+                        key = "{}.{}".format(n, param)
+                        weight, _, _ = get_key_weight(self.model, key)
+                        weight.seed_key = key
+                        set_dirty(weight, dirty)
+                        weight_size = weight.numel() * weight.element_size()
+                        if vbar is not None and not hasattr(weight, "_v"):
+                            weight._v = vbar.alloc(weight_size)
+                            allocated_size += weight_size
+
+            logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
+
+            self.model.device = device_to
+            self.model.current_weight_patches_uuid = self.patches_uuid
+
+        for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
+            #These are all super dangerous. Who knows what the custom nodes actually do here...
+            callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
+
+        self.apply_hooks(self.forced_hooks, force_apply=True)
+
+    def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
+        assert not force_patch_weights #See above
+        assert self.load_device != torch.device("cpu")
+
+        vbar = self._vbar_get()
+        return 0 if vbar is None else vbar.free_memory(memory_to_free)
+
+    def partially_unload_ram(self, ram_to_unload):
+        loading = self._load_list(prio_comfy_cast_weights=True)
+        for x in loading:
+            _, _, _, _, m, _ = x
+            ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
+            if ram_to_unload <= 0:
+                return
+
+    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
+        #This isn't used by the core at all and can only be to load a model out of
+        #the control of proper model_management. If you are a custom node author reading
+        #this, the correct pattern is to call load_models_gpu() to get a proper
+        #managed load of your model.
+        assert not load_weights
+        return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
+
+    def unpatch_model(self, device_to=None, unpatch_weights=True):
+        super().unpatch_model(device_to=None, unpatch_weights=False)
+
+        if unpatch_weights:
+            self.partially_unload_ram(1e32)
+            self.partially_unload(None)
+
+    def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
+        assert not force_patch_weights #See above
+        with self.use_ejected(skip_and_inject_on_exit_only=True):
+            dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
+
+            self.unpatch_model(self.offload_device, unpatch_weights=False)
+            self.patch_model(load_weights=False)
+
+            try:
+                self.load(device_to, dirty=dirty)
+            except Exception as e:
+                self.detach()
+                raise e
+        #ModelPatcher::partially_load returns a number on what got loaded but
+        #nothing in core uses this and we have no data in the Dynamic world. Hit
+        #the custom node devs with a None rather than a 0 that would mislead any
+        #logic they might have.
+        return None
+
+    def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
+        assert False #Should be unreachable - we don't ever cache in the new implementation
+
+    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
+        if key not in combined_patches:
+            return
+
+        raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments from ComfyUI startup")
+
+    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
+        pass
+
+CoreModelPatcher = ModelPatcher
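Note (illustrative): `CoreModelPatcher` is an alias, not a third class. It defaults to the plain `ModelPatcher` here and is rebound at startup when dynamic VRAM is usable (see the `main.py` hunk below), and `ModelPatcherDynamic.__new__` still falls back to `ModelPatcher` for CPU load devices:

import comfy.model_patcher as mp

# Default binding (this file): plain ModelPatcher.
assert mp.CoreModelPatcher is mp.ModelPatcher
# After startup with comfy-aimdo present and --fast dynamic_vram, main.py rebinds:
#   mp.CoreModelPatcher = mp.ModelPatcherDynamic
# so call sites can construct CoreModelPatcher(...) unconditionally.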
162
comfy/ops.py
@@ -23,6 +23,12 @@ from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
 import json
+import comfy.memory_management
+import comfy.pinned_memory
+import comfy.utils
+
+import comfy_aimdo.model_vbar
+import comfy_aimdo.torch

 def run_every_op():
     if torch.compiler.is_compiling():
@@ -72,7 +78,109 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
+    offload_stream = None
+    xfer_dest = None
+
+    signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
+    if signature is not None:
+        xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
+    resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+
+    if not resident:
+        xfer_source = [s.weight, s.bias]
+
+        pin = comfy.pinned_memory.get_pin(s)
+        if pin is not None:
+            xfer_source = [pin]
+            resident = True #If pinned data exists, it always has LowVram already applied
+
+        dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
+        offload_stream = comfy.model_management.get_offload_stream(device)
+        if xfer_dest is None and offload_stream is not None:
+            xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
+            if xfer_dest is None:
+                offload_stream = comfy.model_management.get_offload_stream(device)
+                xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
+        if xfer_dest is None:
+            xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
+            offload_stream = None
+
+        #send it over
+        comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+        comfy.model_management.sync_stream(device, offload_stream)
+
+    pin = None
+    if signature is not None:
+        #If we are able to increase our load level (e.g. user reduces resolution or batch number)
+        #reclaim the pin previously used for offload.
+        comfy.pinned_memory.unpin_memory(s)
+    elif not resident:
+        #prepare a new pin
+        assert comfy.pinned_memory.get_pin(s) is None
+        comfy.pinned_memory.pin_memory(s)
+        pin = comfy.pinned_memory.get_pin(s)
+
+    params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest)
+    weight = params[0]
+    bias = params[1]
+
+    def post_cast(s, param_key, x, dtype, resident, update_weight):
+        lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+        fns = getattr(s, param_key + "_function", [])
+
+        orig = x
+
+        def to_dequant(tensor, dtype):
+            tensor = tensor.to(dtype=dtype)
+            if isinstance(tensor, QuantizedTensor):
+                tensor = tensor.dequantize()
+            return tensor
+
+        if orig.dtype != dtype or len(fns) > 0:
+            x = to_dequant(x, dtype)
+        if not resident and lowvram_fn is not None:
+            x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
+            #FIXME: this is not accurate, we need to be sensitive to the compute dtype
+            x = lowvram_fn(x)
+        if (isinstance(orig, QuantizedTensor) and
+            (orig.dtype == dtype and len(fns) == 0 or update_weight)):
+            seed = comfy.utils.string_to_seed(s.seed_key)
+            y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
+            if orig.dtype == dtype and len(fns) == 0:
+                #The layer actually wants our freshly saved QT
+                x = y
+        else:
+            y = x
+        if update_weight:
+            orig.copy_(y)
+        for f in fns:
+            x = f(x)
+        return x
+
+    update_weight = signature is not None or pin is not None
+
+    weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
+    if s.bias is not None:
+        bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
+    s._v_signature = signature
+
+    if pin is not None:
+        xfer_dest = comfy.memory_management.interpret_gathered_like([pin], xfer_dest)[0]
+        #FIXME: This might be the wrong thing to do. Some reading suggests the DMA engine
+        #is posted writes and the compute stream could just fire and forget here. That
+        #would save this sync and some stalling on the offload stream that is better off
+        #running ahead to the next layer to read.
+        if offload_stream is not None:
+            offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=offload_stream, r=pin)
+
+    #FIXME: weird offload return protocol
+    return weight, bias, (offload_stream, device if signature is not None else None, None)
+
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
     # NOTE: offloadable=False is a legacy and if you are a custom node author reading this please pass
     # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
     # will add async-offload support to your cast and improve performance.
@@ -87,22 +195,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
         if device is None:
             device = input.device

+    non_blocking = comfy.model_management.device_supports_non_blocking(device)
+
+    if hasattr(s, "_v"):
+        return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
+
     if offloadable and (device != s.weight.device or
                         (s.bias is not None and device != s.bias.device)):
         offload_stream = comfy.model_management.get_offload_stream(device)
     else:
         offload_stream = None

-    non_blocking = comfy.model_management.device_supports_non_blocking(device)
+    bias = None
+    weight = None
+
+    if offload_stream is not None and not args.cuda_malloc:
+        cast_buffer_size = comfy.memory_management.vram_aligned_size([s.weight, s.bias])
+        cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
+        #The streams can be uneven in buffer capability and reject us. Retry to get the other stream
+        if cast_buffer is None:
+            offload_stream = comfy.model_management.get_offload_stream(device)
+            cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
+        params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], cast_buffer)
+        weight = params[0]
+        bias = params[1]
+
     weight_has_function = len(s.weight_function) > 0
     bias_has_function = len(s.bias_function) > 0

-    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)

-    bias = None
     if s.bias is not None:
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+        bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)

     comfy.model_management.sync_stream(device, offload_stream)
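Example (illustrative, echoing the NOTE above for custom node authors): the offloadable pattern is a cast / use / uncast triple, mirroring what `mixed_precision_ops` does further down in this diff — uncast only after the last use of the weight and bias:

def forward_comfy_cast_weights(self, input):
    weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
    out = torch.nn.functional.linear(input, weight, bias)
    uncast_bias_weight(self, weight, bias, offload_stream)  # after last use of weight/bias
    return out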
@@ -110,6 +234,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
     weight_a = weight

     if s.bias is not None:
+        bias = bias.to(dtype=bias_dtype)
         for f in s.bias_function:
             bias = f(bias)
@@ -131,14 +256,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
     if offload_stream is None:
         return
     os, weight_a, bias_a = offload_stream
+    device = None
+    #FIXME: This is not good RTTI
+    if not isinstance(weight_a, torch.Tensor):
+        comfy_aimdo.model_vbar.vbar_unpin(s._v)
+        device = weight_a
     if os is None:
         return
-    if weight_a is not None:
-        device = weight_a.device
-    else:
-        if bias_a is None:
-            return
-        device = bias_a.device
+    if device is None:
+        if weight_a is not None:
+            device = weight_a.device
+        else:
+            if bias_a is None:
+                return
+            device = bias_a.device
     os.wait_stream(comfy.model_management.current_stream(device))
@@ -653,8 +784,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
         def _forward(self, input, weight, bias):
             return torch.nn.functional.linear(input, weight, bias)

-        def forward_comfy_cast_weights(self, input):
-            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+        def forward_comfy_cast_weights(self, input, compute_dtype=None):
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
             x = self._forward(input, weight, bias)
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x
@@ -664,6 +795,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):

             input_shape = input.shape
             reshaped_3d = False
+            #If cast needs to apply lora, it should be done in the compute dtype
+            compute_dtype = input.dtype

             if (getattr(self, 'layout_type', None) is not None and
                 not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
@@ -682,7 +815,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
                 scale = comfy.model_management.cast_to_device(scale, input.device, None)
                 input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)

-            output = self.forward_comfy_cast_weights(input)
+            output = self.forward_comfy_cast_weights(input, compute_dtype)

             # Reshape output back to 3D if input was 3D
             if reshaped_3d:
30
comfy/pinned_memory.py
Normal file
@@ -0,0 +1,30 @@
+import torch
+import comfy.model_management
+import comfy.memory_management
+
+from comfy.cli_args import args
+
+def get_pin(module):
+    return getattr(module, "_pin", None)
+
+def pin_memory(module):
+    if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
+        return
+    #FIXME: This is a RAM cache trigger event
+    params = [module.weight, module.bias]
+    size = comfy.memory_management.vram_aligned_size(params)
+    pin = torch.empty((size,), dtype=torch.uint8)
+    if comfy.model_management.pin_memory(pin):
+        module._pin = pin
+    else:
+        module.pin_failed = True
+        return False
+    return True
+
+def unpin_memory(module):
+    if get_pin(module) is None:
+        return 0
+    size = module._pin.numel() * module._pin.element_size()
+    comfy.model_management.unpin_memory(module._pin)
+    del module._pin
+    return size
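Example (illustrative): a small usage sketch of this helper pair. The `module` is any op carrying `.weight`/`.bias` plus the `pin_failed` flag that the dynamic loader sets during `load()`:

import torch

lin = torch.nn.Linear(8, 8)
lin.pin_failed = False  # normally set by ModelPatcherDynamic.load()

if pin_memory(lin):
    # _pin is a page-locked host buffer sized for weight+bias, 1KB-aligned
    assert get_pin(lin) is not None

freed = unpin_memory(lin)  # returns how many bytes of pinned RAM were released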
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
 import torch
 from functools import partial
-import collections
 from comfy import model_management
 import math
 import logging
 import comfy.sampler_helpers
@@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
     to_batch_temp.reverse()
     to_batch = to_batch_temp[:1]

-    free_memory = model_management.get_free_memory(x_in.device)
+    free_memory = model.current_patcher.get_free_memory(x_in.device)
     for i in range(1, len(to_batch_temp) + 1):
         batch_amount = to_batch_temp[:len(to_batch_temp)//i]
         input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
49
comfy/sd.py
@@ -128,7 +128,7 @@ class CLIP:
                 logging.warning("Had to shift TE back.")

         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         #Match torch.float32 hardcode upcast in TE implementation
         self.patcher.set_model_compute_dtype(torch.float32)
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@@ -288,7 +288,7 @@ class CLIP:

     def load_sd(self, sd, full_model=False):
         if full_model:
-            return self.cond_stage_model.load_state_dict(sd, strict=False)
+            return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
         else:
             return self.cond_stage_model.load_sd(sd)
@@ -665,13 +665,6 @@ class VAE:
             self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()

-        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
-        if len(m) > 0:
-            logging.warning("Missing VAE keys {}".format(m))
-
-        if len(u) > 0:
-            logging.debug("Leftover VAE keys {}".format(u))
-
         if device is None:
             device = model_management.vae_device()
         self.device = device
@@ -682,7 +675,18 @@ class VAE:
         self.first_stage_model.to(self.vae_dtype)
         self.output_device = model_management.intermediate_device()

-        self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
+        mp = comfy.model_patcher.CoreModelPatcher
+        if self.disable_offload:
+            mp = comfy.model_patcher.ModelPatcher
+        self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device)
+
+        m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
+        if len(m) > 0:
+            logging.warning("Missing VAE keys {}".format(m))
+
+        if len(u) > 0:
+            logging.debug("Leftover VAE keys {}".format(u))
+
         logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
         self.model_size()
@@ -797,7 +801,7 @@ class VAE:
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = model_management.get_free_memory(self.device)
+            free_memory = self.patcher.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
@@ -816,6 +820,7 @@ class VAE:
                 do_tile = True

         if do_tile:
+            torch.cuda.empty_cache()
             dims = samples_in.ndim - 2
             if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
@@ -871,7 +876,7 @@ class VAE:
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = model_management.get_free_memory(self.device)
+            free_memory = self.patcher.get_free_memory(self.device)
             batch_number = int(free_memory / max(1, memory_used))
             batch_number = max(1, batch_number)
             samples = None
@@ -891,6 +896,7 @@ class VAE:
                 do_tile = True

         if do_tile:
+            torch.cuda.empty_cache()
             if self.latent_dim == 3:
                 tile = 256
                 overlap = tile // 4
@@ -1315,7 +1321,7 @@ def load_gligen(ckpt_path):
     model = gligen.load_gligen(data)
     if model_management.should_use_fp16():
         model = model.half()
-    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
+    return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())

 def model_detection_error_hint(path, state_dict):
     filename = os.path.basename(path)
@@ -1403,7 +1409,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     if output_model:
         inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
         model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
-        model.load_model_weights(sd, diffusion_model_prefix)
+        model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
+        model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())

     if output_vae:
         vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
@@ -1446,7 +1453,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         logging.debug("left over keys: {}".format(left_over))

     if output_model:
-        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
         if inital_load_device != torch.device("cpu"):
             logging.info("loaded diffusion model directly to GPU")
             model_management.load_models_gpu([model_patcher], force_full_load=True)
@@ -1538,13 +1544,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
         model_config.optimizations["fp8"] = True

     model = model_config.get_model(new_sd, "")
-    model = model.to(offload_device)
-    model.load_model_weights(new_sd, "")
+    model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
+    if not model_management.is_device_cpu(offload_device):
+        model.to(offload_device)
+    model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
     left_over = sd.keys()
     if len(left_over) > 0:
         logging.info("left over keys in diffusion model: {}".format(left_over))
-    return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
+
+    return model_patcher

 def load_diffusion_model(unet_path, model_options={}):
     sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
@@ -1575,9 +1582,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
     if metadata is None:
         metadata = {}

-    model_management.load_models_gpu(load_models, force_patch_weights=True)
+    model_management.load_models_gpu(load_models)
     clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
-    sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
+    sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
     for k in extra_keys:
         sd[k] = extra_keys[k]
||||
@ -1298,3 +1298,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
return state_dict, metadata
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
for byte in data:
|
||||
if isinstance(byte, str):
|
||||
byte = ord(byte)
|
||||
crc ^= byte
|
||||
for _ in range(8):
|
||||
if crc & 1:
|
||||
crc = (crc >> 1) ^ 0xEDB88320
|
||||
else:
|
||||
crc >>= 1
|
||||
return crc ^ 0xFFFFFFFF
|
||||
|
||||
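The relocated helper is a textbook reflected CRC-32 (polynomial 0xEDB88320, init and final XOR 0xFFFFFFFF), so for ASCII key names it should agree with `zlib.crc32`; a quick illustrative check:

import zlib

key = "diffusion_model.input_blocks.0.0.weight"
assert string_to_seed(key) == zlib.crc32(key.encode("ascii"))  # deterministic per-key seed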
@@ -258,9 +258,9 @@ class ModelPatchLoader:
             config['broken'] = True
         model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)

-        model.load_state_dict(sd)
-        model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
-        return (model,)
+        model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
+        model.load_state_dict(sd, assign=model_patcher.is_dynamic())
+        return (model_patcher,)


 class DiffSynthCnetPatch:
@@ -1,8 +1,10 @@
 import os
 import importlib.util
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
 import subprocess

+import comfy_aimdo.control
+
 #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
 def get_gpu_names():
     if os.name == 'nt':
@@ -85,8 +87,14 @@ if not args.cuda_malloc:
     except:
         pass

+if enables_dynamic_vram() and comfy_aimdo.control.lib is not None:
+    args.cuda_malloc = False
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""
+
-if args.cuda_malloc and not args.disable_cuda_malloc:
+if args.disable_cuda_malloc:
+    args.cuda_malloc = False
+
+if args.cuda_malloc:
     env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
     if env_var is None:
         env_var = "backend:cudaMallocAsync"
19
execution.py
@@ -1,3 +1,4 @@
+import gc
 import copy
 import heapq
 import inspect
@@ -9,9 +10,11 @@ import traceback
 from enum import Enum
 from typing import List, Literal, NamedTuple, Optional, Union
 import asyncio
+from contextlib import nullcontext

 import torch

+import comfy.memory_management
 import comfy.model_management
 from latent_preview import set_preview_method
 import nodes
@@ -515,7 +518,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         def pre_execute_cb(call_index):
             # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
             GraphBuilder.set_default_prefix(unique_id, call_index, 0)
-        output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
+
+        #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows
+        #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
+        #that we just want to cull out each model run.
+        allocator = comfy.memory_management.aimdo_allocator
+        with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
+            output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
+            torch.cuda.synchronize()
+        if allocator is not None:
+            #FIXME: this is probably a little zealous.
+            #Torch code comments say some stuff about not actually freeing tensors on mempool
+            #context release. Explicitly garbage collect now.
+            gc.collect()
+            torch.cuda.empty_cache()

         if has_pending_tasks:
             pending_async_nodes[unique_id] = output_data
             unblock = execution_list.add_external_block(unique_id)
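Example (illustrative, assuming a CUDA build of PyTorch recent enough to expose `torch.cuda.MemPool` and a pluggable-allocator instance like the `aimdo_allocator` above): the per-node mempool pattern in isolation. `run_node` is a hypothetical stand-in for `get_output_data`:

import gc
import torch

# route allocations inside the block into a private pool backed by the
# pluggable allocator, then release the pool's pages when the node finishes
pool = torch.cuda.MemPool(allocator.allocator())  # `allocator` as in the hunk above
with torch.cuda.use_mem_pool(pool):
    run_node()                                    # hypothetical node execution
    torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()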
35
main.py
@@ -5,7 +5,7 @@ import os
 import importlib.util
 import folder_paths
 import time
-from comfy.cli_args import args
+from comfy.cli_args import args, enables_dynamic_vram
 from app.logger import setup_logger
 from app.assets.scanner import seed_assets
 import itertools
@@ -173,6 +173,30 @@ import gc
 if 'torch' in sys.modules:
     logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")

+
+has_aimdo = False
+
+import comfy_aimdo.control
+
+if comfy_aimdo.control.lib is not None:
+    if args.verbose == 'DEBUG':
+        comfy_aimdo.control.set_log_debug()
+    elif args.verbose == 'CRITICAL':
+        comfy_aimdo.control.set_log_critical()
+    elif args.verbose == 'ERROR':
+        comfy_aimdo.control.set_log_error()
+    elif args.verbose == 'WARNING':
+        comfy_aimdo.control.set_log_warning()
+    else: #INFO
+        comfy_aimdo.control.set_log_info()
+
+    if enables_dynamic_vram():
+        logging.info("DynamicVRAM support detected and enabled")
+        has_aimdo = True
+else:
+    if enables_dynamic_vram():
+        logging.info("No native comfy-aimdo install detected. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
+
 import comfy.utils

 import execution
@@ -184,6 +208,15 @@ import comfyui_version
 import app.logger
 import hook_breaker_ac10a0

+import comfy.memory_management
+import comfy.model_patcher
+
+if has_aimdo:
+    comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
+    comfy_aimdo.control.init_vram_guard(comfy.model_management.get_torch_device().index)
+else:
+    comfy.memory_management.aimdo_allocator = None
+
 def cuda_malloc_warning():
     device = comfy.model_management.get_torch_device()
     device_name = comfy.model_management.get_torch_device_name(device)
@@ -22,6 +22,7 @@ alembic
 SQLAlchemy
 av>=14.2.0
 comfy-kitchen>=0.2.6
+comfy-aimdo>=0.1.1

 #non essential dependencies:
 kornia>=0.7.1