Mirror of https://github.com/comfyanonymous/ComfyUI.git
MPDynamic: force load flux img_in weight (Fixes flux1 canny+depth lora crash) (#12446)
* lora: add weight shape calculations. This lets the loader know whether a lora will change the shape of a weight, so it can take appropriate action.

* MPDynamic: force load the flux img_in weight. This weight is a bit special in that the lora changes its geometry. That is rather unique: it is not handled by the existing size estimate and works for neither offloading nor dynamic_vram. Fix dynamic_vram by treating it as a special case. Ideally these lora geometry changes would be fully precalculated at load time, but the priority is to get these models working first.
parent ecd2a19661 · commit c0370044cd
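For context on the fix: flux canny/depth control loras ship an img_in projection with more input channels than the base model's, so applying them has to grow the base weight before the lora delta can land. A minimal sketch of that kind of zero-padding, for illustration only (the 64→128 channel figures are assumptions, and pad_to_shape here is a stand-in, not the repo's pad_tensor_to_shape):

import torch

def pad_to_shape(tensor: torch.Tensor, new_shape) -> torch.Tensor:
    # Zero-pad up to new_shape, keeping the original values in the
    # leading corner; illustrative stand-in for pad_tensor_to_shape.
    padded = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
    padded[tuple(slice(0, s) for s in tensor.shape)] = tensor
    return padded

img_in = torch.randn(3072, 64)                  # assumed base img_in geometry
print(pad_to_shape(img_in, (3072, 128)).shape)  # torch.Size([3072, 128])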
@@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
     return padded_tensor
 
+
+def calculate_shape(patches, weight, key, original_weights=None):
+    current_shape = weight.shape
+
+    for p in patches:
+        v = p[1]
+        offset = p[3]
+
+        # Offsets restore the old shape; lists force a diff without metadata
+        if offset is not None or isinstance(v, list):
+            continue
+
+        if isinstance(v, weight_adapter.WeightAdapterBase):
+            adapter_shape = v.calculate_shape(key)
+            if adapter_shape is not None:
+                current_shape = adapter_shape
+            continue
+
+        # Standard diff logic with padding
+        if len(v) == 2:
+            patch_type, patch_data = v[0], v[1]
+            if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
+                current_shape = patch_data[0].shape
+
+    return current_shape
 
 def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
     for p in patches:
         strength = p[0]
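A hedged usage sketch of the new comfy.lora.calculate_shape: given the patch list stored for a key, it predicts the post-patch geometry without materializing any weights, which is what lets ModelPatcherDynamic below decide between its normal path and a force load. The weight shape and key are made-up stand-ins, and an empty patch list simply echoes the current shape:

import torch
import comfy.lora

weight = torch.zeros(3072, 64)  # assumed base geometry
patches = []                    # stand-in for patcher.patches.get(key, [])

new_shape = comfy.lora.calculate_shape(patches, weight, "img_in.weight")
if new_shape != weight.shape:
    # A resizing lora: size estimates based on weight.shape would
    # undercount, so this module needs special handling.
    ...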
@@ -1514,8 +1514,10 @@ class ModelPatcherDynamic(ModelPatcher):
                 weight, _, _ = get_key_weight(self.model, key)
                 if weight is None:
-                    return 0
+                    return (False, 0)
                 if key in self.patches:
+                    if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
+                        return (True, 0)
                     setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
                     num_patches += 1
                 else:
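Why the shape check above returns (True, 0) and skips the usual accounting: the dynamic-vram slot would be sized from the current geometry while the patched weight needs more. A self-contained back-of-the-envelope, with the 64→128 channel change assumed for illustration:

import torch

base = torch.zeros(3072, 64, dtype=torch.bfloat16)
patched = torch.zeros(3072, 128, dtype=torch.bfloat16)  # post-lora geometry

reserved = base.numel() * base.element_size()      # 393216 bytes
needed = patched.numel() * patched.element_size()  # 786432 bytes
print(needed - reserved)  # 393216: a 384 KiB shortfall on this one weight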
@@ -1529,7 +1531,13 @@ class ModelPatcherDynamic(ModelPatcher):
                 model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
                 weight._model_dtype = model_dtype
                 geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
-                return comfy.memory_management.vram_aligned_size(geometry)
+                return (False, comfy.memory_management.vram_aligned_size(geometry))
+
+            def force_load_param(self, param_key, device_to):
+                key = key_param_name_to_key(n, param_key)
+                if key in self.backup:
+                    comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
+                self.patch_weight_to_device(key, device_to=device_to)
 
             if hasattr(m, "comfy_cast_weights"):
                 m.comfy_cast_weights = True
@@ -1537,13 +1545,19 @@ class ModelPatcherDynamic(ModelPatcher):
                 m.seed_key = n
                 set_dirty(m, dirty)
 
-                v_weight_size = 0
-                v_weight_size += setup_param(self, m, n, "weight")
-                v_weight_size += setup_param(self, m, n, "bias")
+                force_load, v_weight_size = setup_param(self, m, n, "weight")
+                force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
+                force_load = force_load or force_load_bias
+                v_weight_size += v_weight_bias
 
-                if vbar is not None and not hasattr(m, "_v"):
-                    m._v = vbar.alloc(v_weight_size)
-                    allocated_size += v_weight_size
+                if force_load:
+                    logging.info(f"Module {n} has resizing Lora - force loading")
+                    force_load_param(self, "weight", device_to)
+                    force_load_param(self, "bias", device_to)
+                else:
+                    if vbar is not None and not hasattr(m, "_v"):
+                        m._v = vbar.alloc(v_weight_size)
+                        allocated_size += v_weight_size
 
             else:
                 for param in params:
@@ -1606,6 +1620,11 @@ class ModelPatcherDynamic(ModelPatcher):
         for m in self.model.modules():
             move_weight_functions(m, device_to)
 
+        keys = list(self.backup.keys())
+        for k in keys:
+            bk = self.backup[k]
+            comfy.utils.set_attr_param(self.model, k, bk.weight)
+
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
         with self.use_ejected(skip_and_inject_on_exit_only=True):
@@ -49,6 +49,12 @@ class WeightAdapterBase:
         """
         raise NotImplementedError
 
+    def calculate_shape(
+        self,
+        key
+    ):
+        return None
+
     def calculate_weight(
         self,
         weight,
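The base-class default of None means "this adapter leaves the weight's shape alone", so existing adapters keep working unchanged. A hedged sketch of how a shape-changing adapter could override it, assuming comfy.weight_adapter exports WeightAdapterBase as the weight_adapter.WeightAdapterBase reference in the first hunk suggests (MyAdapter and its new_shape attribute are hypothetical; LoRAAdapter in the next hunk is the real in-tree override):

import comfy.weight_adapter as weight_adapter

class MyAdapter(weight_adapter.WeightAdapterBase):
    # Hypothetical adapter whose payload carries a target geometry.
    def __init__(self, loaded_keys, weights, new_shape=None):
        self.loaded_keys = loaded_keys
        self.weights = weights
        self.new_shape = new_shape

    def calculate_shape(self, key):
        # Report the post-patch geometry, or None when unchanged.
        return self.new_shape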
@@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
         else:
             return None
 
+    def calculate_shape(
+        self,
+        key
+    ):
+        reshape = self.weights[5]
+        return tuple(reshape) if reshape is not None else None
+
     def calculate_weight(
         self,
         weight,
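So for plain loras the reshape target rides along in slot 5 of the weights tuple; that entry is how a flux canny/depth lora announces its new img_in geometry. A self-contained illustration of just the accessor (the surrounding tuple layout and the concrete shape are assumptions):

class ReshapeCarrier:
    # Minimal stand-in mirroring LoRAAdapter's storage: self.weights[5]
    # optionally holds the reshape target.
    def __init__(self, weights):
        self.weights = weights

    def calculate_shape(self, key):
        reshape = self.weights[5]
        return tuple(reshape) if reshape is not None else None

print(ReshapeCarrier((None,) * 5 + ([3072, 128],)).calculate_shape("img_in.weight"))  # (3072, 128)
print(ReshapeCarrier((None,) * 6).calculate_shape("img_in.weight"))                   # None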