diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6deb71e12..020d723df 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -122,6 +122,11 @@ class LowVramPatch: self.set_func = set_func def __call__(self, weight): + if weight.dtype == torch.float8_e4m3fn or weight.dtype == torch.float8_e5m2: + temp_weight = weight.to(torch.bfloat16) + patched_weight = comfy.lora.calculate_weight(self.patches[self.key], temp_weight, self.key, intermediate_dtype=torch.bfloat16) + return patched_weight.to(weight.dtype) + return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2