diff --git a/comfy/model_management.py b/comfy/model_management.py index 9145d06b1..de73d04ea 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1221,7 +1221,8 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if weight.dtype in FLOAT8_TYPES: weight = weight.float() weight = weight.to(target_device, dtype=torch.float16) - dtype = torch.float16 if dtype is None else dtype + if dtype is None or dtype in FLOAT8_TYPES: + dtype = torch.float16 if device is None or weight.device == device: if not copy: