Fix bypass dtype/device moving

Kohaku-Blueleaf 2026-01-31 16:58:37 +08:00
parent c9b633d84f
commit 40c77373b6


@@ -21,6 +21,7 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
+import comfy.model_management
 from .base import WeightAdapterBase, WeightAdapterTrainBase
 from comfy.patcher_extension import PatcherInjection
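
The added import is what enables the device fix: comfy.model_management.get_torch_device() reports the device ComfyUI actually computes on, which under weight offloading is not necessarily where the module's weights currently sit. A minimal sketch of that distinction, assuming the comfy package is importable; the Linear layer here is a stand-in for an offloaded module, not code from this file:

import torch
import torch.nn as nn
import comfy.model_management

# Stand-in for an offloaded module: its weights rest on CPU between passes.
layer = nn.Linear(8, 8)

stale_device = layer.weight.device                          # cpu while offloaded
compute_device = comfy.model_management.get_torch_device()  # e.g. cuda:0

# Moving the adapter weights to compute_device (the new behaviour) keeps the
# bypass forward on the compute device; following stale_device (the old
# behaviour) can leave the adapter on CPU and break or slow the bypass forward.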
@@ -181,18 +182,21 @@ class BypassForwardHook:
             )
             return  # Already injected
-        # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
-        device = None
+        # Move adapter weights to compute device (GPU)
+        # Use get_torch_device() instead of module.weight.device because
+        # with offloading, module weights may be on CPU while compute happens on GPU
+        device = comfy.model_management.get_torch_device()
+        # Get dtype from module weight if available
         dtype = None
         if hasattr(self.module, "weight") and self.module.weight is not None:
-            device = self.module.weight.device
             dtype = self.module.weight.dtype
         elif hasattr(self.module, "W_q"):  # Quantized layers might use different attr
-            device = self.module.W_q.device
             dtype = self.module.W_q.dtype
-        if device is not None:
-            self._move_adapter_weights_to_device(device, dtype)
+        # Only use dtype if it's a standard float type, not quantized
+        if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
+            dtype = None
+        self._move_adapter_weights_to_device(device, dtype)
         self.original_forward = self.module.forward
         self.module.forward = self._bypass_forward
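
The dtype guard is the other half of the fix: quantized layers expose a packed attribute such as W_q whose storage dtype (for example torch.uint8) must not be propagated to the adapter tensors. A small standalone sketch of that rule, using a hypothetical adapter_cast_dtype helper purely for illustration:

import torch

def adapter_cast_dtype(module_dtype):
    # Mirror the commit's rule: only plain float dtypes are forwarded to the
    # adapter; anything else (packed quantized storage) falls back to None,
    # meaning the adapter keeps its own dtype.
    if module_dtype in (torch.float32, torch.float16, torch.bfloat16):
        return module_dtype
    return None

print(adapter_cast_dtype(torch.bfloat16))  # torch.bfloat16 -> adapter is cast
print(adapter_cast_dtype(torch.uint8))     # None -> adapter keeps its own dtype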