Mirror of https://github.com/comfyanonymous/ComfyUI.git
Improve logging and typing information for LoRA patches in ComfyUI

parent 021d0d4f57
commit cde95eb71d
comfy/lora.py

@@ -24,6 +24,7 @@ import torch
 from . import model_base
 from . import model_management
 from . import utils
+from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue

 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@@ -35,8 +36,8 @@ LORA_CLIP_MAP = {
 }


-def load_lora(lora, to_load):
-    patch_dict = {}
+def load_lora(lora, to_load) -> PatchDict:
+    patch_dict: PatchDict = {}
     loaded_keys = set()
     for x in to_load:
         alpha_name = "{}.alpha".format(x)
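The annotations above make the contract explicit: `load_lora` returns a `PatchDict`, a mapping from a model state-dict key (or a `(key, offset)` / `(key, offset, function)` tuple) to a `(patch_type, tensors)` pair. A minimal sketch of what one entry can look like; the tensor tuple layout is illustrative only, since the body of `load_lora` that builds it is not part of this diff, and the `comfy.lora_types` import path is assumed from the relative imports above:

```python
import torch

from comfy.lora_types import PatchDict  # import path assumed

# Placeholder low-rank factors; in load_lora these come out of the LoRA file itself.
up = torch.randn(320, 16)
down = torch.randn(16, 320)

patch_dict: PatchDict = {
    # model state-dict key -> (patch_type, tensors); the real tuple may carry
    # extra fields (alpha, mid weights, a DoRA scale, ...) depending on the patch type.
    "diffusion_model.input_blocks.1.1.proj_in.weight": ("lora", (up, down, 1.0, None, None)),
}
```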
@@ -197,11 +198,13 @@ def load_lora(lora, to_load):
     return patch_dict


-def model_lora_keys_clip(model, key_map={}):
+def model_lora_keys_clip(model, key_map=None):
+    if key_map is None:
+        key_map = {}
     sdk = model.state_dict().keys()
     for k in sdk:
         if k.endswith(".weight"):
-            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k # generic lora format without any weird key names

     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
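Replacing the `key_map={}` default with `key_map=None` above sidesteps Python's shared-mutable-default pitfall: a default value is evaluated once, at function definition time, so every call that relied on the default would have written into the same dict across unrelated models. A small standalone illustration (the `collect` functions are hypothetical, not from this codebase):

```python
def collect(key, mapping={}):      # the dict is created once, when the function is defined
    mapping[key] = True
    return mapping

print(collect("a"))  # {'a': True}
print(collect("b"))  # {'a': True, 'b': True} -- state leaked from the previous call


def collect_fixed(key, mapping=None):   # the pattern adopted in the diff above
    if mapping is None:
        mapping = {}                    # a fresh dict per call
    mapping[key] = True
    return mapping

print(collect_fixed("a"))  # {'a': True}
print(collect_fixed("b"))  # {'b': True}
```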
@@ -253,7 +256,7 @@ def model_lora_keys_clip(model, key_map={}):
                 if clip_l_present:
                     t5_index += 1
                     if t5_index == 2:
-                        key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
+                        key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k # OneTrainer Flux
                         t5_index += 1

                 key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
@@ -275,7 +278,9 @@ def model_lora_keys_clip(model, key_map={}):
     return key_map


-def model_lora_keys_unet(model, key_map={}):
+def model_lora_keys_unet(model, key_map=None):
+    if key_map is None:
+        key_map = {}
     sd = model.state_dict()
     sdk = sd.keys()

@@ -292,7 +297,7 @@ def model_lora_keys_unet(model, key_map={}):
             unet_key = "diffusion_model.{}".format(diffusers_keys[k])
             key_lora = k[:-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = unet_key
-            key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
+            key_map["lycoris_{}".format(key_lora)] = unet_key # simpletuner lycoris format

             diffusers_lora_prefix = ["", "unet."]
             for p in diffusers_lora_prefix:
@@ -315,10 +320,9 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora
                 key_map[key_lora] = to

-                key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
+                key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) # simpletuner lycoris format
                 key_map[key_lora] = to

-
     if isinstance(model, model_base.AuraFlow): # Diffusers lora AuraFlow
         diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
@@ -340,7 +344,7 @@ def model_lora_keys_unet(model, key_map={}):
                 to = diffusers_keys[k]
                 key_map["transformer.{}".format(k[:-len(".weight")])] = to # simpletrainer and probably regular diffusers flux lora format
                 key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # simpletrainer lycoris
-                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # onetrainer

     return key_map

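The two `model_lora_keys_*` functions above only build a name-translation table; they never touch weights. Each entry maps a key prefix as it appears inside a LoRA file (kohya `lora_unet_*` / `lora_te*`, diffusers `transformer.*`, OneTrainer `lora_transformer_*`, simpletuner `lycoris_*`, ...) to the matching key in the model's own state dict, and the resulting dict is what `load_lora` receives as `to_load`. A hedged illustration of a few entries; the layer names are examples only and depend on the actual model:

```python
# LoRA-file key prefix -> key in the model's state dict (illustrative entries).
key_map = {
    "lora_unet_input_blocks_1_1_proj_in": "diffusion_model.input_blocks.1.1.proj_in.weight",
    "lycoris_input_blocks_1_1_proj_in": "diffusion_model.input_blocks.1.1.proj_in.weight",
    "lora_te_text_model_encoder_layers_0_mlp_fc1": "clip_l.transformer.text_model.encoder.layers.0.mlp.fc1.weight",
}
```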
@@ -400,13 +404,13 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
     return padded_tensor


-def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
+def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_dtype=torch.float32):
     for p in patches:
         strength = p[0]
         v = p[1]
         strength_model = p[2]
-        offset = p[3]
-        function = p[4]
+        offset: PatchOffset = p[3]
+        function: PatchConversionFunction = p[4]
         if function is None:
             function = lambda a: a

@@ -419,9 +423,9 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             weight *= strength_model

         if isinstance(v, list):
-            v = (calculate_weight(v[1:], v[0][1](model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
+            v = (calculate_weight(v[1:], v[0][1](model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype),)

-        patch_type = ""
+        patch_type: PatchType = ""
         if len(v) == 1:
             patch_type = "diff"
         elif len(v) == 2:
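The `patch_type: PatchType = ""` annotation above matches the `Literal` defined in the new comfy/lora_types.py later in this commit: the shape of `v` selects between "diff", "lora", "loha", "lokr" and "glora" handling further down in `calculate_weight`. As a reference point, the simplest case, a "diff" patch, reduces to a scaled add; a minimal standalone sketch, not the function's actual code (which also handles padding, dtype casting, offsets and DoRA):

```python
import torch


def apply_diff_patch(weight: torch.Tensor, diff: torch.Tensor, strength: float) -> torch.Tensor:
    # A "diff" patch stores a full-weight delta; applying it is a weighted add.
    return weight + strength * diff.to(weight.dtype)


w = torch.zeros(4, 4)
d = torch.ones(4, 4)
print(apply_diff_patch(w, d, 0.5))  # every element becomes 0.5
```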
@@ -574,7 +578,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):

             try:
                 if old_glora:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) # old lycoris glora
                 else:
                     if weight.dim() > 2:
                         lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
comfy/lora_types.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from typing import Literal, Any, NamedTuple, Protocol, Callable
+
+import torch
+
+PatchOffset = tuple[int, int, int]
+PatchFunction = Any
+PatchDictKey = str | tuple[str, PatchOffset] | tuple[str, PatchOffset, PatchFunction]
+PatchType = Literal["lora", "loha", "lokr", "glora", "diff", ""]
+PatchDictValue = tuple[PatchType, tuple]
+PatchDict = dict[PatchDictKey, PatchDictValue]
+
+
+class PatchConversionFunction(Protocol):
+    def __call__(self, tensor: torch.Tensor, **kwargs) -> torch.Tensor:
+        ...
+
+
+class PatchWeightTuple(NamedTuple):
+    weight: torch.Tensor
+    convert_func: PatchConversionFunction | Callable[[torch.Tensor], torch.Tensor]
+
+
+class PatchTuple(NamedTuple):
+    strength_patch: float
+    patch: PatchDictValue
+    strength_model: float
+    offset: PatchOffset
+    function: PatchFunction
+
+
+ModelPatchesDictValue = list[PatchTuple | PatchWeightTuple]
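A short sketch of how these aliases and NamedTuples compose, based only on the definitions above and the model_patcher.py hunks below; the key and tensors are placeholders, and the `comfy.lora_types` import path is assumed:

```python
import torch

from comfy.lora_types import PatchDict, PatchTuple, PatchWeightTuple, ModelPatchesDictValue

key = "diffusion_model.middle_block.1.proj_in.weight"          # placeholder state-dict key
patch_dict: PatchDict = {key: ("diff", (torch.zeros(4, 4),))}  # placeholder "diff" patch

# What ModelPatcher.add_patches appends per key; offset/function stay None for plain string keys.
patches_for_key: ModelPatchesDictValue = [
    PatchTuple(strength_patch=1.0, patch=patch_dict[key], strength_model=1.0, offset=None, function=None),
]

# What the key-patch export path prepends: the current weight plus an identity convert function.
patches_for_key.insert(0, PatchWeightTuple(weight=torch.zeros(4, 4), convert_func=lambda a, **kwargs: a))
```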
comfy/model_patcher.py

@@ -30,9 +30,12 @@ from . import model_management, lora
 from . import utils
 from .comfy_types import UnetWrapperFunction
 from .float import stochastic_rounding
+from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
 from .model_base import BaseModel
 from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions

+logger = logging.getLogger(__name__)
+

 def string_to_seed(data):
     crc = 0xFFFFFFFF
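The new module-level `logger = logging.getLogger(__name__)` means records emitted by this file are attributed to its dotted module name rather than to the root logger, so its level and handlers can be tuned per module without touching global logging configuration. A minimal illustration of the difference, independent of this codebase:

```python
import logging

logging.basicConfig(format="%(name)s: %(message)s", level=logging.DEBUG)

logging.debug("via the root logger")   # prints "root: via the root logger"

logger = logging.getLogger(__name__)
logger.debug("via the module logger")  # prints "__main__: ..." when run as a script,
                                       # or the module's dotted name when imported
```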
@@ -134,7 +137,7 @@ class ModelPatcher(ModelManageable):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model: BaseModel | torch.nn.Module = model
-        self.patches = {}
+        self.patches: dict[PatchDictKey, ModelPatchesDictValue] = {}
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
@@ -143,7 +146,7 @@ class ModelPatcher(ModelManageable):
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
-        self.patches_uuid = uuid.uuid4()
+        self.patches_uuid: uuid.UUID = uuid.uuid4()
         self.ckpt_name = ckpt_name
         self._memory_measurements = MemoryMeasurements(self.model)

@@ -202,7 +205,7 @@ class ModelPatcher(ModelManageable):

         if self.patches_uuid == clone.patches_uuid:
             if len(self.patches) != len(clone.patches):
-                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+                logger.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
             else:
                 return True

@@ -316,14 +319,15 @@ class ModelPatcher(ModelManageable):
         if hasattr(self.model, "get_dtype"):
             return self.model.get_dtype()

-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
+    def add_patches(self, patches: PatchDict, strength_patch=1.0, strength_model=1.0) -> list[PatchDictKey]:
+        p: set[PatchDictKey] = set()
         model_sd = self.model.state_dict()
+        k: PatchDictKey
         for k in patches:
             offset = None
             function = None
             if isinstance(k, str):
-                key = k
+                key: str = k
             else:
                 offset = k[1]
                 key = k[0]
@@ -333,7 +337,7 @@ class ModelPatcher(ModelManageable):
             if key in model_sd:
                 p.add(k)
                 current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
+                current_patches.append(PatchTuple(strength_patch, patches[k], strength_model, offset, function))
                 self.patches[key] = current_patches

         self.patches_uuid = uuid.uuid4()
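With the typed signature above, `add_patches` takes the `PatchDict` produced by `load_lora` and returns the list of keys that actually matched the model's state dict. A hedged, toy-sized sketch of that contract; the `comfy.model_patcher` import path is assumed, the "diff" patch layout is illustrative, and whether a bare `nn.Module` satisfies everything `ModelPatcher` needs depends on code outside this diff:

```python
import torch

from comfy.lora_types import PatchDict
from comfy.model_patcher import ModelPatcher  # module path assumed

# A toy module so the sketch stays small; real use wraps a diffusion or CLIP model.
net = torch.nn.Linear(4, 4, bias=False)
patcher = ModelPatcher(net, load_device=torch.device("cpu"), offload_device=torch.device("cpu"))

# Normally produced by load_lora(); "weight" is the only key in the toy state dict.
patches: PatchDict = {"weight": ("diff", (torch.full((4, 4), 0.01),))}

applied = patcher.add_patches(patches, strength_patch=1.0)
print(applied)  # ["weight"] -- keys that matched; an empty list means no LoRA key fit this model
```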
@@ -354,9 +358,9 @@ class ModelPatcher(ModelManageable):
                 convert_func = lambda a, **kwargs: a

             if k in self.patches:
-                p[k] = [(weight, convert_func)] + self.patches[k]
+                p[k] = [PatchWeightTuple(weight, convert_func)] + self.patches[k]
             else:
-                p[k] = [(weight, convert_func)]
+                p[k] = [PatchWeightTuple(weight, convert_func)]
         return p

     def model_state_dict(self, filter_prefix=None):
@@ -460,17 +464,17 @@ class ModelPatcher(ModelManageable):

                     self.patch_weight_to_device(weight_key, device_to=device_to)
                     self.patch_weight_to_device(bias_key, device_to=device_to)
-                    logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+                    logger.debug("lowvram: loaded module regularly {} {}".format(n, m))
                     m.comfy_patched_weights = True

         for x in load_completely:
             x[2].to(device_to)

         if lowvram_counter > 0:
-            logging.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self._memory_measurements.model_lowvram = True
         else:
-            logging.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self._memory_measurements.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
@@ -574,7 +578,7 @@ class ModelPatcher(ModelManageable):
                 m.comfy_cast_weights = True
                 m.comfy_patched_weights = False
                 memory_freed += module_mem
-                logging.debug("freed {}".format(n))
+                logger.debug("freed {}".format(n))

         self._memory_measurements.model_lowvram = True
         self._memory_measurements.lowvram_patch_counter += patch_counter