diff --git a/comfy/ops.py b/comfy/ops.py index 640622fd1..93b228ce4 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -599,6 +599,7 @@ class MixedPrecisionOps(disable_weight_init): 'block_size': qconfig.get("group_size", None), } if layout_params['scale'] is not None: + layout_params['scale'] = layout_params['scale'].to(device=device) manually_loaded_keys.append(weight_scale_key) self.weight = torch.nn.Parameter( @@ -611,7 +612,7 @@ class MixedPrecisionOps(disable_weight_init): _v = state_dict.pop(param_key, None) if _v is None: continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + self.register_buffer(param_name, _v.to(device=device)) manually_loaded_keys.append(param_key) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)