add generic support for quantize tensor casting and proper scaling factor
commit b95c05d095 (parent db730ee283)
@@ -337,6 +337,16 @@ def generic_copy_(func, args, kwargs):
         return qt_dest
     return func(*args, **kwargs)
 
+@register_generic_util(torch.ops.aten.to.dtype)
+def generic_to_dtype(func, args, kwargs):
+    """Handle .to(dtype) calls - dtype conversion only."""
+    src = args[0]
+    if isinstance(src, QuantizedTensor):
+        # For dtype-only conversion, just change the orig_dtype, no real cast is needed
+        target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
+        src._layout_params["orig_dtype"] = target_dtype
+        return src
+    return func(*args, **kwargs)
 
 @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
 def generic_has_compatible_shallow_copy_type(func, args, kwargs):
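The new generic_to_dtype handler plugs into a registry that maps aten ops to small utilities for the quantized tensor subclass, so a dtype-only .to() call rewrites the recorded orig_dtype instead of casting the packed low-precision payload. The snippet below is a minimal, self-contained sketch of that registry pattern; QT_UTILS, register_generic_util_demo, FakeQuantizedTensor and dispatch are illustrative stand-ins, not ComfyUI's actual implementation.

    import torch

    # Registry mapping aten ops to handler functions, mirroring register_generic_util.
    QT_UTILS = {}

    def register_generic_util_demo(op):
        def decorator(handler):
            QT_UTILS[op] = handler
            return handler
        return decorator

    class FakeQuantizedTensor:
        """Stand-in carrying a packed payload, a scale, and layout metadata."""
        def __init__(self, qdata, scale, orig_dtype):
            self.qdata = qdata          # packed low-precision data
            self.scale = scale          # per-tensor scale factor
            self._layout_params = {"orig_dtype": orig_dtype}

    @register_generic_util_demo(torch.ops.aten.to.dtype)
    def demo_to_dtype(func, args, kwargs):
        """Mirror of the new handler: a dtype-only .to() just rewrites the metadata."""
        src = args[0]
        if isinstance(src, FakeQuantizedTensor):
            target_dtype = args[1] if len(args) > 1 else kwargs.get("dtype")
            src._layout_params["orig_dtype"] = target_dtype
            return src
        return func(*args, **kwargs)

    # Dispatch helper: consult the registry for the op, otherwise run it unchanged.
    def dispatch(op, *args, **kwargs):
        handler = QT_UTILS.get(op)
        if handler is not None:
            return handler(op, args, kwargs)
        return op(*args, **kwargs)

    qt = FakeQuantizedTensor(torch.zeros(4, dtype=torch.float8_e4m3fn),
                             torch.tensor(1.0), torch.float16)
    dispatch(torch.ops.aten.to.dtype, qt, torch.bfloat16)
    print(qt._layout_params["orig_dtype"])  # torch.bfloat16 -- payload untouched

Running the sketch changes only the metadata that later dequantization consults; the FP8 payload itself is never cast.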
@@ -383,10 +393,11 @@ class TensorCoreFP8Layout(QuantizedLayout):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)
 
-        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
+        tensor_fp32 = tensor.to(torch.float32)
+        tensor_scaled = tensor_fp32 * (1.0 / scale)
         # TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
-        # lp_amax = torch.finfo(dtype).max
-        # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+        lp_amax = torch.finfo(dtype).max
+        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
         qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
 
         layout_params = {
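This hunk applies the scaling factor in float32 rather than in the input's dtype and enables the clamp to the FP8 format's finite range before the final cast. Below is a hedged sketch of the resulting quantize/dequantize round trip, assuming a per-tensor scale derived from the tensor's absolute maximum (the scale computation itself is not part of this hunk); fp8_quantize_demo and fp8_dequantize_demo are illustrative names, not ComfyUI functions.

    import torch

    def fp8_quantize_demo(tensor, dtype=torch.float8_e4m3fn):
        lp_amax = torch.finfo(dtype).max                   # 448.0 for e4m3fn
        # Assumed per-tensor scale from the absolute max; guarded against zero input.
        scale = tensor.abs().amax().to(torch.float32) / lp_amax
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).tiny)

        # Scale in fp32 so the factor is not rounded to the input's (possibly fp16) precision.
        tensor_fp32 = tensor.to(torch.float32)
        tensor_scaled = tensor_fp32 * (1.0 / scale)

        # Clamp to the representable fp8 range before the cast, matching the now-active clamp.
        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
        return qdata, scale

    def fp8_dequantize_demo(qdata, scale, orig_dtype=torch.float16):
        # Inverse path: widen, multiply the scale back in, return to the recorded dtype.
        return (qdata.to(torch.float32) * scale).to(orig_dtype)

    x = torch.randn(8, 8).half()
    q, s = fp8_quantize_demo(x)
    print((fp8_dequantize_demo(q, s) - x).abs().max())     # small round-trip error

Doing the multiply in float32 keeps the reciprocal scale from being rounded away at fp16 precision, and the clamp makes out-of-range values saturate at the format's maximum instead of overflowing during the cast to e4m3.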