mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
Fix bypass dtype/device moving
This commit is contained in:
parent
c9b633d84f
commit
40c77373b6
@ -21,6 +21,7 @@ from typing import Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||||
from comfy.patcher_extension import PatcherInjection
|
from comfy.patcher_extension import PatcherInjection
|
||||||
|
|
||||||
@ -181,18 +182,21 @@ class BypassForwardHook:
|
|||||||
)
|
)
|
||||||
return # Already injected
|
return # Already injected
|
||||||
|
|
||||||
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
|
# Move adapter weights to compute device (GPU)
|
||||||
device = None
|
# Use get_torch_device() instead of module.weight.device because
|
||||||
|
# with offloading, module weights may be on CPU while compute happens on GPU
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
|
# Get dtype from module weight if available
|
||||||
dtype = None
|
dtype = None
|
||||||
if hasattr(self.module, "weight") and self.module.weight is not None:
|
if hasattr(self.module, "weight") and self.module.weight is not None:
|
||||||
device = self.module.weight.device
|
|
||||||
dtype = self.module.weight.dtype
|
dtype = self.module.weight.dtype
|
||||||
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
|
|
||||||
device = self.module.W_q.device
|
|
||||||
dtype = self.module.W_q.dtype
|
|
||||||
|
|
||||||
if device is not None:
|
# Only use dtype if it's a standard float type, not quantized
|
||||||
self._move_adapter_weights_to_device(device, dtype)
|
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
||||||
|
dtype = None
|
||||||
|
|
||||||
|
self._move_adapter_weights_to_device(device, dtype)
|
||||||
|
|
||||||
self.original_forward = self.module.forward
|
self.original_forward = self.module.forward
|
||||||
self.module.forward = self._bypass_forward
|
self.module.forward = self._bypass_forward
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user