mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-12 18:22:53 +08:00
model_patcher: Fix safetensors saving of fp8 (#13835)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This was missing proper weight scale casting in the saving path.
This commit is contained in: parent 428c323780, commit 20e439419c.
@ -242,6 +242,37 @@ class LazyCastingParam(torch.nn.Parameter):
|
|||||||
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
|
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
class LazyCastingQuantizedParam:
    """Lazily materializes the CPU-side state dict of a quantized weight.

    On first access the weight is patched on the model's load device,
    every piece of its state dict is moved to CPU once, and the result is
    cached so all pieces (weight, scales, ...) are cast exactly one time.
    """

    def __init__(self, model, key):
        self.model = model
        self.key = key
        # Populated on the first state_dict_tensor() call.
        self.cpu_state_dict = None

    def state_dict_tensor(self, state_dict_key):
        """Return the CPU tensor stored under ``state_dict_key``.

        The full patched state dict is built and cached on first use;
        subsequent calls only perform a dictionary lookup.
        """
        if self.cpu_state_dict is None:
            patched = self.model.patch_weight_to_device(
                self.key, device_to=self.model.load_device, return_weight=True
            )
            cached = {}
            for name, tensor in patched.state_dict(self.key).items():
                cached[name] = tensor.to("cpu")
            self.cpu_state_dict = cached
        return self.cpu_state_dict[state_dict_key]
class LazyCastingParamPiece(torch.nn.Parameter):
    """One tensor of a quantized weight's serialized state dict.

    Resolution is deferred: ``to()`` asks the shared caster for the
    properly cast CPU tensor instead of moving this placeholder itself.
    The ``device`` property reports ``CustomTorchDevice`` — presumably a
    sentinel so the saving path always calls ``.to()`` (defined elsewhere
    in this file; confirm against its declaration).
    """

    def __new__(cls, caster, state_dict_key, tensor):
        # torch.nn.Parameter subclasses must allocate via __new__ with
        # the underlying tensor data.
        return super().__new__(cls, tensor)

    def __init__(self, caster, state_dict_key, tensor):
        self.caster = caster
        self.state_dict_key = state_dict_key

    @property
    def device(self):
        return CustomTorchDevice

    def to(self, *args, **kwargs):
        # Drop our reference to the caster before resolving so this
        # placeholder no longer keeps the model reachable afterwards.
        resolver = self.caster
        del self.caster
        return resolver.state_dict_tensor(self.state_dict_key)
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -1463,20 +1494,37 @@ class ModelPatcher:
|
|||||||
self.clear_cached_hook_weights()
|
self.clear_cached_hook_weights()
|
||||||
|
|
||||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
    """Assemble the diffusion model's state dict for saving to disk.

    Weights belonging to comfy cast-enabled ops are wrapped in lazy
    proxies so the patched / properly cast values are only materialized
    when the saver reads them. A quantized weight expands into one
    LazyCastingParamPiece per state-dict entry (e.g. weight + scale),
    all sharing a single LazyCastingQuantizedParam so the group is cast
    together exactly once.
    """
    original_state_dict = self.model.diffusion_model.state_dict()
    unet_state_dict = {}
    # Iterate over a mutable key list: handling one quantized weight may
    # consume its sibling keys (scales, etc.) ahead of their turn.
    keys = list(original_state_dict)
    while len(keys) > 0:
        k = keys.pop(0)
        v = original_state_dict[k]
        op_keys = k.rsplit('.', 1)
        # Only weight/bias tensors can require casting; everything else
        # passes through unchanged.
        if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
            unet_state_dict[k] = v
            continue
        try:
            op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
        except:
            # Key does not resolve to a module attribute; keep raw tensor.
            unet_state_dict[k] = v
            continue
        # Ops without comfy cast support, or whose weights were already
        # patched in place, keep their stored tensor as-is.
        if not op or not hasattr(op, "comfy_cast_weights") or \
            (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
            unet_state_dict[k] = v
            continue
        key = "diffusion_model." + k
        weight = comfy.utils.get_attr(self.model, key)
        if isinstance(weight, QuantizedTensor) and k in original_state_dict:
            # Quantized weights serialize as several tensors; wrap each
            # piece with a proxy sharing one caster for the whole group.
            qt_state_dict = weight.state_dict(k)
            caster = LazyCastingQuantizedParam(self, key)
            for group_key in (x for x in qt_state_dict if x in original_state_dict):
                if group_key in keys:
                    # Sibling piece handled here; don't process it again.
                    keys.remove(group_key)
                unet_state_dict.pop(group_key, "")
                unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
            continue
        unet_state_dict[k] = LazyCastingParam(self, key, weight)
    return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user