Merge 8c374c8b90 into 0c18842acb
commit 7bdd0105c2
@@ -130,9 +130,22 @@ class LowVramPatch:
         self.set_func = set_func
 
     def __call__(self, weight):
+        # Detect SageAttention and skip conversion for compatibility
+        sage_attention_active = False
+        try:
+            import comfy.cli_args
+            sage_attention_active = hasattr(comfy.cli_args.args, 'use_sage_attention') and \
+                comfy.cli_args.args.use_sage_attention
+        except:
+            pass
+
         intermediate_dtype = weight.dtype
-        if self.convert_func is not None:
+
+        # Skip convert_func when SageAttention is active (compatibility mode)
+        if self.convert_func is not None and not sage_attention_active:
             weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+        elif sage_attention_active and self.convert_func is not None:
+            logging.debug(f"Skipping convert_func for {self.key} (SageAttention compatibility)")
 
         if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
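The detection block in the hunk above reduces to a single flag read with a defensive fallback. A minimal standalone sketch of the same check, assuming a ComfyUI environment where comfy.cli_args exposes the parsed arguments (the getattr fallback here is only a compact stand-in for the hasattr/and pattern in the diff):

def sage_attention_enabled():
    # True only when ComfyUI was launched with --use-sage-attention; any
    # import or attribute failure falls back to False, like the try/except
    # in the patch above.
    try:
        import comfy.cli_args
        return bool(getattr(comfy.cli_args.args, 'use_sage_attention', False))
    except Exception:
        return False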
@@ -140,10 +153,16 @@ class LowVramPatch:
             if self.set_func is None:
                 return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
             else:
-                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
+                # Skip set_func when SageAttention is active (compatibility mode)
+                if not sage_attention_active:
+                    return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
+                else:
+                    return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
 
         out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
-        if self.set_func is not None:
+
+        # Skip set_func when SageAttention is active (compatibility mode)
+        if self.set_func is not None and not sage_attention_active:
             return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
         else:
             return out
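The promoted-dtype branch in the second hunk amounts to: use set_func only when it exists and SageAttention is off, otherwise fall back to stochastic rounding of the patched result. A minimal, self-contained rendering of that branch, assuming a ComfyUI environment (comfy.float.stochastic_rounding and string_to_seed are the helpers the diff already calls; out, weight, key and set_func are stand-ins):

import comfy.float
from comfy.model_patcher import string_to_seed

def finish_patched_weight(out, weight, key, set_func, sage_attention_active):
    # Mirrors the first block of the second hunk: set_func runs only when
    # SageAttention is inactive; otherwise the float32 result is
    # stochastically rounded back to the stored weight dtype.
    if set_func is not None and not sage_attention_active:
        return set_func(out, seed=string_to_seed(key), return_weight=True)
    return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(key))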