ops: Put weight cast on the offload stream

The weight dtype cast needs to run on the offload stream. With the cast
outside that stream, a black screen was reproduced with low-resolution
images on a slow bus when using FP8.
This commit is contained in:
Rattus 2025-11-10 13:04:09 +10:00
parent dea899f221
commit a5c32e5b08

View File

@@ -110,9 +110,9 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.bias_function:
bias = f(bias)
weight = weight.to(dtype=dtype)
if weight_has_function:
if weight_has_function or weight.dtype != dtype:
with wf_context:
weight = weight.to(dtype=dtype)
for f in s.weight_function:
weight = f(weight)