diff --git a/comfy/model_management.py b/comfy/model_management.py index c5fb1eea4..acae7b6e3 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -981,6 +981,9 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if not copy: if dtype is None or weight.dtype == dtype: return weight + if stream is not None: + with stream: + return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy) if stream is not None: