torch compile fixes.

commit b8afb60ee8
parent 295a0170d6
@@ -252,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn):
 
 def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
-    if target_dtype is not None and target_dtype != qt.dtype:
-        logging.warning(
-            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
-            f"but not supported for quantized tensors. Ignoring dtype."
-        )
-
     if target_layout is not None and target_layout != torch.strided:
         logging.warning(
             f"QuantizedTensor: layout change requested to {target_layout}, "
@@ -277,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
     logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
     new_q_data = qt._qdata.to(device=target_device)
     new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+    if target_dtype is not None:
+        new_params["orig_dtype"] = target_dtype
     new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
     logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
     return new_qt
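
Taken together, the two hunks above change how a requested dtype is handled during device transfer: rather than warning that dtype conversion is unsupported and silently ignoring it, the transfer path now records the requested dtype as the layout's orig_dtype, presumably so later dequantization can target it. A minimal sketch of the new behavior under that assumption; FakeQuantizedTensor and the trimmed helper below are illustrative stand-ins, not the repo's actual classes:

import torch

class FakeQuantizedTensor:
    """Illustrative stand-in for ComfyUI's QuantizedTensor (not the real class)."""
    def __init__(self, qdata, layout_type, layout_params):
        self._qdata = qdata                  # quantized payload, e.g. fp8
        self._layout_type = layout_type
        self._layout_params = layout_params  # metadata, including "orig_dtype"

def handle_device_transfer(qt, target_device, target_dtype=None):
    # Mirrors the patched logic: move the payload, copy the params, and
    # record a requested dtype instead of warning and dropping it.
    new_q_data = qt._qdata.to(device=target_device)
    new_params = dict(qt._layout_params)
    if target_dtype is not None:
        new_params["orig_dtype"] = target_dtype
    return FakeQuantizedTensor(new_q_data, qt._layout_type, new_params)

qt = FakeQuantizedTensor(
    torch.zeros(4, dtype=torch.float8_e4m3fn),
    "tensor_core_fp8",
    {"orig_dtype": torch.float16},
)
moved = handle_device_transfer(qt, "cpu", target_dtype=torch.bfloat16)
print(moved._layout_params["orig_dtype"])  # torch.bfloat16, no warning emitted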
@@ -400,7 +396,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
     def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype
 
-        if scale == "recalculate":
+        if isinstance(scale, str) and scale == "recalculate":
             scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
 
         if scale is not None:
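
On one plausible reading, the isinstance guard is the torch.compile fix in this hunk: scale can arrive as a Tensor rather than a string, and short-circuit evaluation means scale == "recalculate" is only evaluated when scale really is a str, so the compiler never has to trace a Tensor-to-string comparison. A small self-contained sketch of that pattern (quantize_scale is a hypothetical reduction of the quantize method above, not ComfyUI's API):

import torch

def quantize_scale(tensor, scale, dtype=torch.float8_e4m3fn):
    # isinstance() short-circuits for Tensor inputs, so the string
    # comparison below is skipped entirely when scale is a Tensor.
    if isinstance(scale, str) and scale == "recalculate":
        scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
    return scale

compiled = torch.compile(quantize_scale)
t = torch.randn(8)
print(compiled(t, "recalculate"))      # scale recomputed from the tensor
print(compiled(t, torch.tensor(0.5)))  # tensor scale passes through untouched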