Disable torch compiler for cast_bias_weight function (#10384)

* Disable torch compiler for cast_bias_weight function

* Fix torch compile.
comfyanonymous 2025-10-17 17:03:28 -07:00 committed by GitHub
parent 92d97380bd
commit 9da397ea2f


@@ -25,6 +25,9 @@ import comfy.rmsnorm
 import contextlib

 def run_every_op():
+    if torch.compiler.is_compiling():
+        return
     comfy.model_management.throw_exception_if_processing_interrupted()

 def scaled_dot_product_attention(q, k, v, *args, **kwargs):
@@ -70,6 +73,7 @@ if torch.cuda.is_available() and torch.backends.cudnn.is_available() and Perform
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

+@torch.compiler.disable()
 def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     if input is not None:
         if dtype is None:
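
For reference, a minimal sketch of how the two torch.compiler hooks used above behave under torch.compile: torch.compiler.is_compiling() lets a helper return early while Dynamo is tracing, and @torch.compiler.disable() keeps a function out of the compiled graph entirely. The names per_op_hook, cast_helper and TinyLayer below are illustrative stand-ins, not ComfyUI's actual API.

import torch

def per_op_hook():
    # Mirrors the run_every_op() guard above: skip host-side bookkeeping
    # (interrupt checks in ComfyUI) while torch.compile is tracing, so the
    # check does not end up in the traced graph.
    if torch.compiler.is_compiling():
        return
    # Eager execution only: side effects such as interrupt checks go here.

@torch.compiler.disable()
def cast_helper(weight, dtype, device):
    # Like cast_bias_weight() above, this is excluded from compilation:
    # Dynamo treats the call as an opaque eager function instead of tracing into it.
    return weight.to(dtype=dtype, device=device, copy=True)

class TinyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        per_op_hook()
        w = cast_helper(self.weight, x.dtype, x.device)
        return x @ w

if __name__ == "__main__":
    layer = TinyLayer()
    compiled = torch.compile(layer)
    print(compiled(torch.randn(2, 4)).shape)  # torch.Size([2, 4])

The trade-off is the same as in the commit: the decorated helper causes a graph break at its call site, but the surrounding compiled regions no longer recompile or error because of the dtype/device casting logic inside it.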