mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 08:12:34 +08:00
Merge d72908617c into 2d4970ff67
This commit is contained in:
commit
dd05437a9c
@ -55,6 +55,12 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.bfloat16:
|
||||
return value.to(dtype=torch.bfloat16)
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
# MPS workaround: if we're on MPS and trying to create float8, force CPU
|
||||
if value.device.type == "mps":
|
||||
cpu_value = value.to("cpu")
|
||||
# Recursive call on CPU
|
||||
return stochastic_rounding(cpu_value, dtype, seed)
|
||||
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
output = torch.empty_like(value, dtype=dtype)
|
||||
|
||||
@ -52,11 +52,6 @@ cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
float8_types = []
|
||||
try:
|
||||
@ -1287,6 +1282,24 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
target_device = device
|
||||
if target_device is not None and is_device_mps(target_device):
|
||||
is_quantized = hasattr(weight, "storage_dtype")
|
||||
real_dtype = weight.storage_dtype if is_quantized else weight.dtype
|
||||
if real_dtype in FLOAT8_TYPES:
|
||||
if is_quantized:
|
||||
# MPS workaround: ensure everything is on CPU before dequantizing
|
||||
# This prevents 'Trying to convert Float8_e4m3fn to the MPS backend' errors
|
||||
# if the tensor or its scale are partially on MPS.
|
||||
if weight.device.type != "cpu":
|
||||
weight = weight.to("cpu")
|
||||
weight = weight.dequantize()
|
||||
if weight.dtype in FLOAT8_TYPES:
|
||||
weight = weight.float()
|
||||
weight = weight.to(target_device, dtype=torch.float16)
|
||||
if dtype is None or dtype in FLOAT8_TYPES:
|
||||
dtype = torch.float16
|
||||
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user