Make sure empty 3d tensors get reshaped correctly

This commit is contained in:
Jedrzej Kosinski 2026-01-05 00:38:00 -08:00
parent cca52736d3
commit 101644cc42

View File

@ -437,7 +437,7 @@ def fp8_linear(self, input):
uncast_bias_weight(self, w, bias, offload_stream)
if tensor_3d:
o = o.reshape((-1, input_shape[1], w.shape[0]))
o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
return o
@ -676,7 +676,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# Reshape output back to 3D if input was 3D
if tensor_3d:
output = output.reshape((-1, input_shape[1], self.weight.shape[0]))
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
return output