diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 1d058bece..d8354cd43 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -337,6 +337,16 @@ def generic_copy_(func, args, kwargs):
         return qt_dest
     return func(*args, **kwargs)
 
+@register_generic_util(torch.ops.aten.to.dtype)
+def generic_to_dtype(func, args, kwargs):
+    """Handle .to(dtype) calls - dtype conversion only."""
+    src = args[0]
+    if isinstance(src, QuantizedTensor):
+        # For dtype-only conversion, just change the orig_dtype, no real cast is needed
+        target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
+        src._layout_params["orig_dtype"] = target_dtype
+        return src
+    return func(*args, **kwargs)
 
 @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
 def generic_has_compatible_shallow_copy_type(func, args, kwargs):
@@ -383,10 +393,11 @@ class TensorCoreFP8Layout(QuantizedLayout):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)
 
-        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
+        tensor_fp32 = tensor.to(torch.float32)
+        tensor_scaled = tensor_fp32 * (1.0 / scale)
         # TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
-        # lp_amax = torch.finfo(dtype).max
-        # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+        lp_amax = torch.finfo(dtype).max
+        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
         qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
 
         layout_params = {
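
For context, below is a minimal, self-contained sketch (not part of the patch) of the quantization math the second hunk switches to: upcast to float32 before scaling, clamp to the representable fp8 range, then cast. The name quantize_fp8_sketch is illustrative, and the snippet assumes a PyTorch build with torch.float8_e4m3fn support; the actual implementation lives in TensorCoreFP8Layout.quantize.

import torch

def quantize_fp8_sketch(tensor, scale, dtype=torch.float8_e4m3fn):
    # Hypothetical helper mirroring the patched scaling path.
    scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
    tensor_fp32 = tensor.to(torch.float32)        # upcast so scaling is not done in bf16/fp16 precision
    tensor_scaled = tensor_fp32 * (1.0 / scale)   # per-tensor scaling
    lp_amax = torch.finfo(dtype).max              # largest finite value representable in the fp8 format
    torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
    return tensor_scaled.to(dtype, memory_format=torch.contiguous_format)

x = torch.randn(4, 4, dtype=torch.bfloat16) * 500.0
q = quantize_fp8_sketch(x, scale=1.0)             # values beyond ~448 (e4m3fn max) are clamped rather than overflowing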