mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-24 05:20:48 +08:00
Fix torch compile issue.
This commit is contained in:
parent
b8afb60ee8
commit
88172a4339
14
comfy/ops.py
14
comfy/ops.py
@ -626,6 +626,20 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
|
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
|
||||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||||
|
|
||||||
|
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||||
|
if recurse:
|
||||||
|
for module in self.children():
|
||||||
|
module._apply(fn)
|
||||||
|
|
||||||
|
for key, param in self._parameters.items():
|
||||||
|
if param is None:
|
||||||
|
continue
|
||||||
|
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||||
|
for key, buf in self._buffers.items():
|
||||||
|
if buf is not None:
|
||||||
|
self._buffers[key] = fn(buf)
|
||||||
|
return self
|
||||||
|
|
||||||
return MixedPrecisionOps
|
return MixedPrecisionOps
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user