diff --git a/comfy/ops.py b/comfy/ops.py index cce68febe..dc06709a1 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -626,6 +626,20 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec assert inplace_update is False # TODO: eventually remove the inplace_update stuff self.weight = torch.nn.Parameter(weight, requires_grad=False) + def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working + if recurse: + for module in self.children(): + module._apply(fn) + + for key, param in self._parameters.items(): + if param is None: + continue + self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + return self + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):