MPDynamic: Add support for model defined dtype
If the model defines a dtype different from the one in the state dict, respect it at load time. The conversion is done as part of the casting process.
This commit is contained in:
parent 65b9729912 · commit 2d96b2fdf1
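Note on the mechanism: a weight can arrive from the state dict in one dtype while the model declares another; the loader tags the tensor and the cast path honors the tag. A minimal sketch of the idea, not ComfyUI's actual loader (only the `_model_dtype` attribute name comes from the diff below; values are hypothetical):

import torch

# A state dict may store a weight in fp16 while the model wants bf16.
# Tagging the tensor lets a later cast step honor the model-declared dtype.
sd_weight = torch.zeros(4, 4, dtype=torch.float16)
sd_weight._model_dtype = torch.bfloat16  # attribute name from the diff; values hypothetical

def cast_for_load(weight, device="cpu"):
    # Prefer the model-defined dtype over whatever the state dict stored.
    dtype = getattr(weight, "_model_dtype", weight.dtype)
    return weight.to(device=device, dtype=dtype)

print(cast_for_load(sd_weight).dtype)  # torch.bfloat16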
@@ -1,5 +1,34 @@
+import math
+import torch
+from typing import NamedTuple
+
 from comfy.quant_ops import QuantizedTensor
 
+
+class TensorGeometry(NamedTuple):
+    shape: any
+    dtype: torch.dtype
+
+    def element_size(self):
+        info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
+        return info.bits // 8
+
+    def numel(self):
+        return math.prod(self.shape)
+
+
+def tensors_to_geometries(tensors, dtype=None):
+    geometries = []
+    for t in tensors:
+        if t is None or isinstance(t, QuantizedTensor):
+            geometries.append(t)
+            continue
+        tdtype = t.dtype
+        if hasattr(t, "_model_dtype"):
+            tdtype = t._model_dtype
+        if dtype is not None:
+            tdtype = dtype
+        geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
+    return geometries
+
 def vram_aligned_size(tensor):
     if isinstance(tensor, list):
         return sum([vram_aligned_size(t) for t in tensor])
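Note: the new helpers above (referenced elsewhere in the diff as comfy.memory_management.*, so this hunk presumably lands in comfy/memory_management.py) let size accounting run on (shape, dtype) descriptions instead of live tensors. A self-contained usage sketch with an illustrative fp8 weight; the class is re-declared only so the snippet runs standalone:

import math
import torch
from typing import NamedTuple

# Re-declared from the hunk above so this snippet runs on its own.
class TensorGeometry(NamedTuple):
    shape: any
    dtype: torch.dtype

    def element_size(self):
        info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
        return info.bits // 8

    def numel(self):
        return math.prod(self.shape)

w = torch.zeros(8, 8, dtype=torch.float8_e4m3fn)           # 64 bytes as stored
geo = TensorGeometry(shape=w.shape, dtype=torch.bfloat16)  # model-defined dtype
print(geo.numel() * geo.element_size())  # 128: the post-cast footprint, not the stored 64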
@@ -1190,12 +1190,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
         assert r is None
         assert stream is None
 
-        r = torch.empty_like(weight, dtype=dtype, device=device)
+        r = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
 
         signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
         if signature is not None:
             raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
-            v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0]
+            v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0]
 
             if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
                 #always take a deep copy even if _v is good, as we have no reasonable point to unpin
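Note: both edits in this hunk follow from the load-time cast: once it happens, the VRAM copy holds _model_dtype elements, so the destination r must be allocated at that dtype and the gathered bytes interpreted against r rather than the source weight. Why the two sizes differ (dtypes hypothetical):

import torch

weight = torch.zeros(16, dtype=torch.float8_e4m3fn)  # 16 bytes in the state dict
r = torch.empty_like(weight, dtype=torch.bfloat16)   # 32 bytes once cast
print(weight.numel() * weight.element_size())  # 16
print(r.numel() * r.element_size())            # 32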
@@ -1504,6 +1504,8 @@ class ModelPatcherDynamic(ModelPatcher):
             weight_function = []
 
             weight, _, _ = get_key_weight(self.model, key)
+            if weight is None:
+                return 0
             if key in self.patches:
                 setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
                 num_patches += 1
@@ -1513,7 +1515,12 @@ class ModelPatcherDynamic(ModelPatcher):
             if key in self.weight_wrapper_patches:
                 weight_function.extend(self.weight_wrapper_patches[key])
             setattr(m, param_key + "_function", weight_function)
-            return comfy.memory_management.vram_aligned_size(weight)
+            geometry = weight
+            if not isinstance(weight, QuantizedTensor):
+                model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
+                weight._model_dtype = model_dtype
+                geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
+            return comfy.memory_management.vram_aligned_size(geometry)
 
         if hasattr(m, "comfy_cast_weights"):
             m.comfy_cast_weights = True
@@ -1535,9 +1542,13 @@ class ModelPatcherDynamic(ModelPatcher):
             weight, _, _ = get_key_weight(self.model, key)
             weight.seed_key = key
             set_dirty(weight, dirty)
-            weight_size = weight.numel() * weight.element_size()
+            geometry = weight
+            model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
+            geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
+            weight_size = geometry.numel() * geometry.element_size()
             if vbar is not None and not hasattr(weight, "_v"):
                 weight._v = vbar.alloc(weight_size)
+                weight._model_dtype = model_dtype
             allocated_size += weight_size
 
         logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
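Note: staged VRAM is now sized by the model dtype rather than the stored dtype. Worked numbers (shape and dtypes hypothetical): promoting an fp8 weight to bf16 doubles its staged footprint.

import math
import torch

shape = (1024, 1024)
stored = torch.finfo(torch.float8_e4m3fn).bits // 8  # 1 byte per element
model = torch.finfo(torch.bfloat16).bits // 8        # 2 bytes per element
n = math.prod(shape)
print(n * stored, n * model)  # 1048576 2097152: 1 MiB stored, 2 MiB staged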
comfy/ops.py (21 changed lines)
@@ -81,6 +81,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
 def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
     offload_stream = None
     xfer_dest = None
+    cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
 
     signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
     if signature is not None:
@@ -88,6 +89,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
     resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
 
     if not resident:
+        cast_dest = None
 
         xfer_source = [ s.weight, s.bias ]
 
@@ -95,6 +97,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
         if pin is not None:
             xfer_source = [ pin ]
             resident = True #If pinned data exists, it always has LowVram already applied
+        else:
+            for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
+                if data is None:
+                    continue
+                if data.dtype != geometry.dtype:
+                    cast_dest = xfer_dest
+                    if cast_dest is None:
+                        cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
+                    xfer_dest = None
+                    break
 
         dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
         offload_stream = comfy.model_management.get_offload_stream(device)
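Note: the added else branch looks for any parameter whose stored dtype differs from its geometry's dtype and, if one is found, reserves a separate byte buffer (cast_dest) for the post-cast layout. The check in isolation, with plain tensors standing in for the TensorGeometry entries:

import torch

weights = [torch.zeros(4, dtype=torch.float8_e4m3fn), None]  # weight, bias
geometries = [torch.zeros(4, dtype=torch.bfloat16), None]    # stand-ins for TensorGeometry
needs_cast = any(d is not None and d.dtype != g.dtype
                 for d, g in zip(weights, geometries))
print(needs_cast)  # True, so a cast destination would be allocated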
@@ -111,6 +123,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
         comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
         comfy.model_management.sync_stream(device, offload_stream)
 
+        if cast_dest is not None:
+            for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest),
+                                           comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
+                if post_cast is not None:
+                    post_cast.copy_(pre_cast)
+            xfer_dest = cast_dest
+
     pin = None
     if signature is not None:
         #If we are able to increase our load level (e.g. user reduces resolution or batch number)
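Note: the per-parameter conversion relies on Tensor.copy_, which casts elementwise when source and destination dtypes differ:

import torch

src = torch.randn(4, dtype=torch.float16)
dst = torch.empty(4, dtype=torch.bfloat16)
dst.copy_(src)    # the copy itself performs the fp16 -> bf16 cast
print(dst.dtype)  # torch.bfloat16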
@@ -122,7 +141,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
             comfy.pinned_memory.pin_memory(s)
             pin = comfy.pinned_memory.get_pin(s)
 
-    params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest)
+    params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
     weight = params[0]
     bias = params[1]
 
@@ -11,7 +11,7 @@ def pin_memory(module):
     if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
         return
     #FIXME: This is a RAM cache trigger event
-    params = [ module.weight, module.bias ]
+    params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ])
     size = comfy.memory_management.vram_aligned_size(params)
     pin = torch.empty((size,), dtype=torch.uint8)
     if comfy.model_management.pin_memory(pin):
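Note: with geometries instead of tensors, the pinned staging buffer is sized for the post-cast bytes. A sketch of the sizing step (alignment ignored; the real vram_aligned_size would add padding):

import math
import torch

def unaligned_size(shape, dtype):
    # Byte count for a (shape, dtype) geometry; illustrative only.
    bits = torch.finfo(dtype).bits if dtype.is_floating_point else torch.iinfo(dtype).bits
    return math.prod(shape) * bits // 8

size = unaligned_size((4, 4), torch.bfloat16)
pin = torch.empty((size,), dtype=torch.uint8)  # raw byte staging buffer
print(pin.numel())  # 32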