mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Merge f8af8dadaf into 35c1470935
This commit is contained in:
commit
f5f68e41c4
14
comfy/ops.py
14
comfy/ops.py
@ -1235,7 +1235,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
run_every_op()
|
||||
|
||||
input_shape = input.shape
|
||||
reshaped_3d = False
|
||||
reshaped_nd = False
|
||||
#If cast needs to apply lora, it should be done in the compute dtype
|
||||
compute_dtype = input.dtype
|
||||
|
||||
@ -1272,12 +1272,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized and quantize_input:
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
# Reshape >=3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[-1]) if input.ndim >= 3 else input
|
||||
|
||||
# Fall back to non-quantized for non-2D tensors
|
||||
if input_reshaped.ndim == 2:
|
||||
reshaped_3d = input.ndim == 3
|
||||
reshaped_nd = input.ndim >= 3
|
||||
# dtype is now implicit in the layout class
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
@ -1292,9 +1292,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
weight_only_quant=weight_only_quant,
|
||||
)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
|
||||
# Reshape output back to original rank if input was >2D
|
||||
if reshaped_nd:
|
||||
output = output.reshape((*input_shape[:-1], self.weight.shape[0]))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user