Avoid inference/train tensor issue

This commit is contained in:
Kohaku-Blueleaf 2026-02-28 00:41:20 +08:00
parent eb33188c8e
commit 3e433cd02d

View File

@ -881,6 +881,10 @@ def set_attr(obj, attr, value):
return prev
def set_attr_param(obj, attr, value):
    """Install *value* on ``obj.attr`` as a frozen (non-trainable) Parameter.

    Tensors created under ``torch.inference_mode`` carry a frozen version
    counter and ``nn.Parameter()`` refuses to wrap them, so such tensors are
    cloned into ordinary tensors first.

    Returns whatever ``set_attr`` returns (the previous attribute value).
    """
    tensor = value.clone() if value.is_inference() else value
    param = torch.nn.Parameter(tensor, requires_grad=False)
    return set_attr(obj, attr, param)
def copy_to_param(obj, attr, value):