porting the cast of the temp_weight to the gpu

This commit is contained in:
lpbb 2023-09-19 14:05:51 +02:00
parent 6d3dee9d16
commit 3e4755ed14
No known key found for this signature in database
GPG Key ID: 51521F624B0297B8

View File

@ -154,7 +154,7 @@ class ModelPatcher:
self.backup[key] = weight.to(self.offload_device)
if device_to is not None:
temp_weight = weight.float().to(device_to, copy=True)
temp_weight = weight.to(device_to, copy=True).float()
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)