mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-09 00:32:31 +08:00
Merge d72908617c into 2d4970ff67
This commit is contained in:
commit
dd05437a9c
@ -55,6 +55,12 @@ def stochastic_rounding(value, dtype, seed=0):
|
|||||||
if dtype == torch.bfloat16:
|
if dtype == torch.bfloat16:
|
||||||
return value.to(dtype=torch.bfloat16)
|
return value.to(dtype=torch.bfloat16)
|
||||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||||
|
# MPS workaround: if we're on MPS and trying to create float8, force CPU
|
||||||
|
if value.device.type == "mps":
|
||||||
|
cpu_value = value.to("cpu")
|
||||||
|
# Recursive call on CPU
|
||||||
|
return stochastic_rounding(cpu_value, dtype, seed)
|
||||||
|
|
||||||
generator = torch.Generator(device=value.device)
|
generator = torch.Generator(device=value.device)
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
output = torch.empty_like(value, dtype=dtype)
|
output = torch.empty_like(value, dtype=dtype)
|
||||||
|
|||||||
@ -52,11 +52,6 @@ cpu_state = CPUState.GPU
|
|||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
|
|
||||||
# Training Related State
|
|
||||||
in_training = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_supported_float8_types():
|
def get_supported_float8_types():
|
||||||
float8_types = []
|
float8_types = []
|
||||||
try:
|
try:
|
||||||
@ -1287,6 +1282,24 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
|||||||
|
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||||
|
target_device = device
|
||||||
|
if target_device is not None and is_device_mps(target_device):
|
||||||
|
is_quantized = hasattr(weight, "storage_dtype")
|
||||||
|
real_dtype = weight.storage_dtype if is_quantized else weight.dtype
|
||||||
|
if real_dtype in FLOAT8_TYPES:
|
||||||
|
if is_quantized:
|
||||||
|
# MPS workaround: ensure everything is on CPU before dequantizing
|
||||||
|
# This prevents 'Trying to convert Float8_e4m3fn to the MPS backend' errors
|
||||||
|
# if the tensor or its scale are partially on MPS.
|
||||||
|
if weight.device.type != "cpu":
|
||||||
|
weight = weight.to("cpu")
|
||||||
|
weight = weight.dequantize()
|
||||||
|
if weight.dtype in FLOAT8_TYPES:
|
||||||
|
weight = weight.float()
|
||||||
|
weight = weight.to(target_device, dtype=torch.float16)
|
||||||
|
if dtype is None or dtype in FLOAT8_TYPES:
|
||||||
|
dtype = torch.float16
|
||||||
|
|
||||||
if device is None or weight.device == device:
|
if device is None or weight.device == device:
|
||||||
if not copy:
|
if not copy:
|
||||||
if dtype is None or weight.dtype == dtype:
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user