Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-17 01:52:59 +08:00)
Disable torch compiler for cast_bias_weight function (#10384)
* Disable torch compiler for cast_bias_weight function
* Fix torch compile.
This commit is contained in:
parent 92d97380bd
commit 9da397ea2f
@@ -25,6 +25,9 @@ import comfy.rmsnorm
 import contextlib

 def run_every_op():
+    if torch.compiler.is_compiling():
+        return
+
     comfy.model_management.throw_exception_if_processing_interrupted()

 def scaled_dot_product_attention(q, k, v, *args, **kwargs):
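The guard added above uses torch.compiler.is_compiling(), which returns True while torch.compile (Dynamo) is tracing. Below is a minimal sketch of the pattern, not ComfyUI code: the hypothetical maybe_check_interrupt() helper stands in for run_every_op(), and skipping its Python-level work during tracing keeps it from forcing a graph break, while plain eager calls still run it.

```python
import torch

def maybe_check_interrupt():
    # Hypothetical stand-in for run_every_op(); not ComfyUI code.
    if torch.compiler.is_compiling():
        # While torch.compile is tracing, skip the eager-only check so it
        # does not introduce a graph break into the compiled region.
        return
    print("eager path: checking for user interrupt")

@torch.compile
def forward(x):
    maybe_check_interrupt()
    return x * 2

print(forward(torch.ones(2)))  # the interrupt check is skipped inside the traced graph
```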
@@ -70,6 +73,7 @@ if torch.cuda.is_available() and torch.backends.cudnn.is_available() and Perform
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

+@torch.compiler.disable()
 def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     if input is not None:
         if dtype is None:
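The decorator added in this hunk, torch.compiler.disable(), excludes a function from Dynamo tracing entirely: compiled callers fall back to running it eagerly. A minimal sketch of that usage follows, with a hypothetical pick_dtype() helper standing in for cast_bias_weight(); it is an illustration of the decorator, not the ComfyUI implementation.

```python
import torch

@torch.compiler.disable()
def pick_dtype(weight, inp):
    # Hypothetical helper for illustration; Dynamo does not trace this body,
    # so its Python-level branching runs eagerly even from compiled callers.
    return inp.dtype if inp is not None else weight.dtype

@torch.compile
def forward(weight, x):
    dtype = pick_dtype(weight, x)
    return x.to(dtype) @ weight.to(dtype)

w = torch.randn(4, 3, dtype=torch.float16)
print(forward(w, torch.randn(2, 4)))
```

Keeping this kind of dtype/device bookkeeping out of the compiled graph is presumably what the commit is after: the casting decisions run in plain Python instead of being traced on every compilation.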