mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-03 10:10:20 +08:00
Fix bypass dtype/device moving
This commit is contained in:
parent
c9b633d84f
commit
40c77373b6
@ -21,6 +21,7 @@ from typing import Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||
from comfy.patcher_extension import PatcherInjection
|
||||
|
||||
@ -181,18 +182,21 @@ class BypassForwardHook:
|
||||
)
|
||||
return # Already injected
|
||||
|
||||
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
|
||||
device = None
|
||||
# Move adapter weights to compute device (GPU)
|
||||
# Use get_torch_device() instead of module.weight.device because
|
||||
# with offloading, module weights may be on CPU while compute happens on GPU
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
# Get dtype from module weight if available
|
||||
dtype = None
|
||||
if hasattr(self.module, "weight") and self.module.weight is not None:
|
||||
device = self.module.weight.device
|
||||
dtype = self.module.weight.dtype
|
||||
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
|
||||
device = self.module.W_q.device
|
||||
dtype = self.module.W_q.dtype
|
||||
|
||||
if device is not None:
|
||||
self._move_adapter_weights_to_device(device, dtype)
|
||||
# Only use dtype if it's a standard float type, not quantized
|
||||
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
||||
dtype = None
|
||||
|
||||
self._move_adapter_weights_to_device(device, dtype)
|
||||
|
||||
self.original_forward = self.module.forward
|
||||
self.module.forward = self._bypass_forward
|
||||
|
||||
Loading…
Reference in New Issue
Block a user