diff --git a/comfy/utils.py b/comfy/utils.py index e331b618b..13b7ca6c8 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -899,7 +899,7 @@ def set_attr(obj, attr, value): def set_attr_param(obj, attr, value): # Clone inference tensors (created under torch.inference_mode) since # their version counter is frozen and nn.Parameter() cannot wrap them. - if value.is_inference(): + if (not torch.is_inference_mode_enabled()) and value.is_inference(): value = value.clone() return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))