Fix SageAttention crash after PR #10276 fp8 weight scaling changes
Problem:
After PR #10276 (commit 139addd5) introduced convert_func/set_func for
proper fp8 weight scaling during LoRA application, users with SageAttention
enabled experience 100% reproducible crashes (Exception 0xC0000005
ACCESS_VIOLATION) during KSampler execution.
Root Cause:
PR #10276 added fp8 weight transformations (scale up -> apply LoRA -> scale
down) to fix LoRA quality with Wan 2.1/2.2 14B fp8 models. These
transformations:
1. Convert weights to float32 and create copies (new memory addresses)
2. Invalidate tensor metadata that SageAttention cached
3. Break SageAttention's internal memory references
4. Cause access violation when SageAttention tries to use old pointers
SageAttention expects weights to stay at their original memory addresses, with
no transformations between the time it caches references and the time it uses them.
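To make the memory-address point concrete, here is a minimal sketch (not
ComfyUI code; the scale factor and LoRA delta below are placeholders) showing
that the scale up -> apply LoRA -> scale down flow operates on a fresh float32
copy, so any pointer cached against the original tensor goes stale:

    import torch

    # float16 stands in for an fp8 dtype in this illustration
    w = torch.randn(4, 4).to(torch.float16)
    cached_ptr = w.data_ptr()  # the kind of reference an attention extension might cache

    # PR #10276-style flow: scale up -> apply LoRA -> scale down, on a float32 copy
    scale = 2.0                                           # illustrative scale factor
    w_f32 = w.to(dtype=torch.float32, copy=True) * scale  # copy=True allocates new storage
    lora_delta = torch.zeros_like(w_f32)                  # stand-in for a calculated LoRA patch
    w_patched = ((w_f32 + lora_delta) / scale).to(w.dtype)

    print(cached_ptr == w_patched.data_ptr())             # False: the cached address no longer matches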
Solution:
Add a conditional bypass in LowVramPatch.__call__ that detects when
SageAttention is active (via the --use-sage-attention flag) and skips the
convert_func/set_func calls. This preserves SageAttention's memory-reference
stability while retaining the PR #10276 benefits for users without
SageAttention.
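For reference, the bypass keys off ComfyUI's existing --use-sage-attention
launch flag; a minimal sketch of the detection (the actual change is in the
diff below), assuming comfy.cli_args is importable in the running process:

    # ComfyUI launched with SageAttention enabled (the bypass only triggers in this case):
    #   python main.py --use-sage-attention

    import comfy.cli_args

    # getattr keeps the check safe on builds where the flag is absent
    sage_attention_active = bool(getattr(comfy.cli_args.args, "use_sage_attention", False))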
Trade-offs:
- When SageAttention is enabled with fp8 models + LoRAs, LoRAs are
  applied to the raw stored fp8 weights instead of properly scaled
  weights (see the numeric sketch after this list)
- Potential quality impact unknown (no issues observed in testing)
- Only affects users who explicitly enable SageAttention flag
- Users without SageAttention continue to benefit from PR #10276
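To illustrate why a quality impact is plausible, a hedged numeric sketch (the
scaling convention here is an assumption for illustration, not ComfyUI's exact
fp8 math): adding a LoRA delta before versus after rescaling changes the
delta's effective magnitude by the scale factor.

    # Assumed convention for illustration only: true_weight = stored_weight * scale
    scale = 2.0
    stored = 0.5              # stored fp8-style value
    true_w = stored * scale   # properly scaled value: 1.0
    delta = 0.1               # LoRA adjustment intended for the properly scaled weight

    proper = (true_w + delta) / scale   # PR #10276 path: stores 0.55
    bypass = stored + delta             # bypass path: stores 0.60
    print(proper, bypass)               # the bypass applies an effectively scale-times-larger delta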
Testing Completed:
- RTX 5090, CUDA 12.8, PyTorch 2.7.0, SageAttention 2.1.1
- Wan 2.2 fp8 models with multiple LoRAs
- Crash eliminated, ~40% SageAttention performance benefit preserved
- No visual quality degradation observed
- Non-SageAttention workflows unaffected
Testing Requested:
- Other GPU architectures (RTX 4090, 3090, etc.)
- Different CUDA/PyTorch version combinations
- fp8 LoRA quality comparison with SageAttention enabled
- Edge cases: mixed fp8/non-fp8 workflows
Files Changed:
- comfy/model_patcher.py: LowVramPatch.__call__ method
Related:
- Issue: SageAttention incompatibility with fp8 weight scaling
- Original PR: #10276 (fp8 LoRA quality fix for Wan models)
- SageAttention: https://github.com/thu-ml/SageAttention
Parent: a125cd84b0
Commit: 8c374c8b90
comfy/model_patcher.py

@@ -130,9 +130,22 @@ class LowVramPatch:
         self.set_func = set_func
 
     def __call__(self, weight):
+        # Detect SageAttention and skip conversion for compatibility
+        sage_attention_active = False
+        try:
+            import comfy.cli_args
+            sage_attention_active = hasattr(comfy.cli_args.args, 'use_sage_attention') and \
+                                    comfy.cli_args.args.use_sage_attention
+        except:
+            pass
+
         intermediate_dtype = weight.dtype
-        if self.convert_func is not None:
+        # Skip convert_func when SageAttention is active (compatibility mode)
+        if self.convert_func is not None and not sage_attention_active:
             weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+        elif sage_attention_active and self.convert_func is not None:
+            logging.debug(f"Skipping convert_func for {self.key} (SageAttention compatibility)")
 
         if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
@@ -140,10 +153,16 @@ class LowVramPatch:
             if self.set_func is None:
                 return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
             else:
-                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
+                # Skip set_func when SageAttention is active (compatibility mode)
+                if not sage_attention_active:
+                    return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
+                else:
+                    return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
 
         out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
-        if self.set_func is not None:
+        # Skip set_func when SageAttention is active (compatibility mode)
+        if self.set_func is not None and not sage_attention_active:
             return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
         else:
             return out