fix: Modify the dtype conversion logic within the cast_to function.

When the dtype is `None` or a `float8` type, it should be uniformly converted to `float16`, preventing type mismatches when converting `float8` weights.
Kiruno-lz 2026-03-06 16:17:21 +08:00
parent 5f01662146
commit b7cbe8afa5


@@ -1221,7 +1221,8 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
     if weight.dtype in FLOAT8_TYPES:
         weight = weight.float()
         weight = weight.to(target_device, dtype=torch.float16)
-    dtype = torch.float16 if dtype is None else dtype
+    if dtype is None or dtype in FLOAT8_TYPES:
+        dtype = torch.float16
     if device is None or weight.device == device:
         if not copy:
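A minimal sketch of the corrected dtype-resolution rule, using string stand-ins for the torch dtypes so it runs without torch installed (the `FLOAT8_TYPES` contents and the `resolve_dtype` helper are illustrative, not the repository's actual definitions):

```python
# Stand-ins for the float8 torch dtypes checked by cast_to.
FLOAT8_TYPES = {"float8_e4m3fn", "float8_e5m2"}

def resolve_dtype(dtype):
    """Mirror the fixed logic: both None and any float8 dtype fall back
    to float16, so a weight that was just upcast to float16 is never
    re-cast back to a float8 type further down in cast_to."""
    if dtype is None or dtype in FLOAT8_TYPES:
        return "float16"
    return dtype

print(resolve_dtype(None))             # float16
print(resolve_dtype("float8_e4m3fn"))  # float16 (the old one-liner left this as float8)
print(resolve_dtype("bfloat16"))       # bfloat16 (explicit dtypes pass through)
```

The old expression `torch.float16 if dtype is None else dtype` only handled the `None` case; an explicitly requested `float8` dtype slipped through and was re-applied to a weight that had already been promoted to `float16`, which is the mismatch this commit removes.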