diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d19d6fe..ee56f8523 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -685,9 +685,9 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False): weight, set_func, convert_func = get_key_weight(self.model, key) - if key not in self.patches: + if key not in self.patches and not force_cast: return weight inplace_update = self.weight_inplace_update or inplace_update @@ -695,7 +695,7 @@ class ModelPatcher: if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) - temp_dtype = comfy.model_management.lora_compute_dtype(device_to) + temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None if device_to is not None: temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) else: @@ -703,9 +703,10 @@ class ModelPatcher: if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) - out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) + out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) + if key in self.patches: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) if return_weight: return out_weight elif inplace_update: @@ -1584,7 +1585,7 @@ class ModelPatcherDynamic(ModelPatcher): key = key_param_name_to_key(n, param_key) if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) - self.patch_weight_to_device(key, device_to=device_to) + self.patch_weight_to_device(key, device_to=device_to, force_cast=True) weight, _, _ = get_key_weight(self.model, key) if weight is not None: self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() @@ -1609,6 +1610,10 @@ class ModelPatcherDynamic(ModelPatcher): m._v = vbar.alloc(v_weight_size) allocated_size += v_weight_size + for param in params: + if param not in ("weight", "bias"): + force_load_param(self, param, device_to) + else: for param in params: key = key_param_name_to_key(n, param)