Fix torch compile issue.

comfyanonymous 2025-12-04 23:58:29 -05:00 committed by GitHub
parent b8afb60ee8
commit 88172a4339

@@ -626,6 +626,20 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             assert inplace_update is False  # TODO: eventually remove the inplace_update stuff
             self.weight = torch.nn.Parameter(weight, requires_grad=False)
 
+        def _apply(self, fn, recurse=True):  # This is to get torch.compile + moving weights to another device working
+            if recurse:
+                for module in self.children():
+                    module._apply(fn)
+
+            for key, param in self._parameters.items():
+                if param is None:
+                    continue
+                self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
+            for key, buf in self._buffers.items():
+                if buf is not None:
+                    self._buffers[key] = fn(buf)
+            return self
+
     return MixedPrecisionOps
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
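For context, the override follows the shape of torch.nn.Module._apply, but re-registers a fresh Parameter via register_parameter instead of swapping each parameter's storage in place; an in-place swap can leave a torch.compile'd graph holding stale tensor references after a .to(device) move. Below is a minimal, self-contained sketch of the same pattern on a toy module; the CompiledFriendlyLinear class, its shapes, and the driver code at the bottom are illustrative assumptions, not part of this commit:

import torch

class CompiledFriendlyLinear(torch.nn.Module):  # hypothetical example module
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.randn(out_features, in_features), requires_grad=False
        )

    def _apply(self, fn, recurse=True):
        # Same approach as the commit: rebuild each parameter as a new
        # torch.nn.Parameter rather than mutating the old one in place.
        if recurse:
            for module in self.children():
                module._apply(fn)

        for key, param in self._parameters.items():
            if param is None:
                continue
            self.register_parameter(
                key, torch.nn.Parameter(fn(param), requires_grad=False)
            )
        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)
        return self

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight)

model = CompiledFriendlyLinear(8, 4)
compiled = torch.compile(model)
print(compiled(torch.randn(2, 8)).shape)        # first trace, on CPU
if torch.cuda.is_available():
    model.to("cuda")                             # _apply re-registers weights on the GPU
    print(compiled(torch.randn(2, 8, device="cuda")).shape)

The idea is that after model.to("cuda"), the compiled wrapper should retrace cleanly against the newly registered parameters instead of tripping over tensors whose storage was replaced underneath the captured graph.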