mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-23 21:10:49 +08:00
Fix torch compile issue.
This commit is contained in:
parent
b8afb60ee8
commit
88172a4339
14
comfy/ops.py
14
comfy/ops.py
@ -626,6 +626,20 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||
if recurse:
|
||||
for module in self.children():
|
||||
module._apply(fn)
|
||||
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user