diff --git a/comfy/lora.py b/comfy/lora.py
index e9bfa68f3..295f7d1d8 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -24,6 +24,7 @@ import torch
 from . import model_base
 from . import model_management
 from . import utils
+from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue
 
 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@@ -35,8 +36,8 @@ LORA_CLIP_MAP = {
 }
 
 
-def load_lora(lora, to_load):
-    patch_dict = {}
+def load_lora(lora, to_load) -> PatchDict:
+    patch_dict: PatchDict = {}
     loaded_keys = set()
     for x in to_load:
         alpha_name = "{}.alpha".format(x)
@@ -197,11 +198,13 @@ def load_lora(lora, to_load):
     return patch_dict
 
 
-def model_lora_keys_clip(model, key_map={}):
+def model_lora_keys_clip(model, key_map=None):
+    if key_map is None:
+        key_map = {}
     sdk = model.state_dict().keys()
     for k in sdk:
         if k.endswith(".weight"):
-            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k  # generic lora format without any weird key names
 
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
@@ -253,7 +256,7 @@ def model_lora_keys_clip(model, key_map={}):
                 if clip_l_present:
                     t5_index += 1
                     if t5_index == 2:
-                        key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
+                        key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k  # OneTrainer Flux
                         t5_index += 1
 
                 key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
@@ -275,7 +278,9 @@ def model_lora_keys_clip(model, key_map={}):
     return key_map
 
 
-def model_lora_keys_unet(model, key_map={}):
+def model_lora_keys_unet(model, key_map=None):
+    if key_map is None:
+        key_map = {}
     sd = model.state_dict()
     sdk = sd.keys()
 
@@ -292,7 +297,7 @@ def model_lora_keys_unet(model, key_map={}):
             unet_key = "diffusion_model.{}".format(diffusers_keys[k])
             key_lora = k[:-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = unet_key
-            key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
+            key_map["lycoris_{}".format(key_lora)] = unet_key  # simpletuner lycoris format
 
             diffusers_lora_prefix = ["", "unet."]
             for p in diffusers_lora_prefix:
@@ -315,10 +320,9 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))  # OneTrainer lora
                 key_map[key_lora] = to
 
-                key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
+                key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))  # simpletuner lycoris format
                 key_map[key_lora] = to
 
-
     if isinstance(model, model_base.AuraFlow):  # Diffusers lora AuraFlow
         diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
@@ -340,7 +344,7 @@ def model_lora_keys_unet(model, key_map={}):
                 to = diffusers_keys[k]
                 key_map["transformer.{}".format(k[:-len(".weight")])] = to  # simpletrainer and probably regular diffusers flux lora format
                 key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  # simpletrainer lycoris
-                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  # onetrainer
 
 
     return key_map
@@ -400,13 +404,13 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
     return padded_tensor
 
 
-def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
+def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_dtype=torch.float32):
     for p in patches:
         strength = p[0]
         v = p[1]
         strength_model = p[2]
-        offset = p[3]
-        function = p[4]
+        offset: PatchOffset = p[3]
+        function: PatchConversionFunction = p[4]
         if function is None:
             function = lambda a: a
 
@@ -419,9 +423,9 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             weight *= strength_model
 
         if isinstance(v, list):
-            v = (calculate_weight(v[1:], v[0][1](model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
+            v = (calculate_weight(v[1:], v[0][1](model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype),)
 
-        patch_type = ""
+        patch_type: PatchType = ""
         if len(v) == 1:
             patch_type = "diff"
         elif len(v) == 2:
@@ -574,7 +578,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
 
             try:
                 if old_glora:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)  # old lycoris glora
                 else:
                     if weight.dim() > 2:
                         lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
diff --git a/comfy/lora_types.py b/comfy/lora_types.py
new file mode 100644
index 000000000..a54645a60
--- /dev/null
+++ b/comfy/lora_types.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from typing import Literal, Any, NamedTuple, Protocol, Callable
+
+import torch
+
+PatchOffset = tuple[int, int, int]
+PatchFunction = Any
+PatchDictKey = str | tuple[str, PatchOffset] | tuple[str, PatchOffset, PatchFunction]
+PatchType = Literal["lora", "loha", "lokr", "glora", "diff", ""]
+PatchDictValue = tuple[PatchType, tuple]
+PatchDict = dict[PatchDictKey, PatchDictValue]
+
+
+class PatchConversionFunction(Protocol):
+    def __call__(self, tensor: torch.Tensor, **kwargs) -> torch.Tensor:
+        ...
+
+
+class PatchWeightTuple(NamedTuple):
+    weight: torch.Tensor
+    convert_func: PatchConversionFunction | Callable[[torch.Tensor], torch.Tensor]
+
+
+class PatchTuple(NamedTuple):
+    strength_patch: float
+    patch: PatchDictValue
+    strength_model: float
+    offset: PatchOffset
+    function: PatchFunction
+
+
+ModelPatchesDictValue = list[PatchTuple | PatchWeightTuple]
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 52a5b9264..a58a9e3d3 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -30,9 +30,12 @@ from . import model_management, lora
 from . import utils
 from .comfy_types import UnetWrapperFunction
 from .float import stochastic_rounding
+from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
 from .model_base import BaseModel
 from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions
 
+logger = logging.getLogger(__name__)
+
 
 def string_to_seed(data):
     crc = 0xFFFFFFFF
@@ -134,7 +137,7 @@ class ModelPatcher(ModelManageable):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model: BaseModel | torch.nn.Module = model
-        self.patches = {}
+        self.patches: dict[PatchDictKey, ModelPatchesDictValue] = {}
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
@@ -143,7 +146,7 @@ class ModelPatcher(ModelManageable):
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
-        self.patches_uuid = uuid.uuid4()
+        self.patches_uuid: uuid.UUID = uuid.uuid4()
         self.ckpt_name = ckpt_name
         self._memory_measurements = MemoryMeasurements(self.model)
 
@@ -202,7 +205,7 @@ class ModelPatcher(ModelManageable):
 
         if self.patches_uuid == clone.patches_uuid:
             if len(self.patches) != len(clone.patches):
-                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+                logger.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
             else:
                 return True
 
@@ -316,14 +319,15 @@ class ModelPatcher(ModelManageable):
         if hasattr(self.model, "get_dtype"):
             return self.model.get_dtype()
 
-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
+    def add_patches(self, patches: PatchDict, strength_patch=1.0, strength_model=1.0) -> list[PatchDictKey]:
+        p: set[PatchDictKey] = set()
         model_sd = self.model.state_dict()
+        k: PatchDictKey
         for k in patches:
             offset = None
             function = None
             if isinstance(k, str):
-                key = k
+                key: str = k
             else:
                 offset = k[1]
                 key = k[0]
@@ -333,7 +337,7 @@ class ModelPatcher(ModelManageable):
             if key in model_sd:
                 p.add(k)
                 current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
+                current_patches.append(PatchTuple(strength_patch, patches[k], strength_model, offset, function))
                 self.patches[key] = current_patches
 
         self.patches_uuid = uuid.uuid4()
@@ -354,9 +358,9 @@ class ModelPatcher(ModelManageable):
                 convert_func = lambda a, **kwargs: a
 
             if k in self.patches:
-                p[k] = [(weight, convert_func)] + self.patches[k]
+                p[k] = [PatchWeightTuple(weight, convert_func)] + self.patches[k]
             else:
-                p[k] = [(weight, convert_func)]
+                p[k] = [PatchWeightTuple(weight, convert_func)]
         return p
 
     def model_state_dict(self, filter_prefix=None):
@@ -460,17 +464,17 @@ class ModelPatcher(ModelManageable):
 
                     self.patch_weight_to_device(weight_key, device_to=device_to)
                     self.patch_weight_to_device(bias_key, device_to=device_to)
-                    logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+                    logger.debug("lowvram: loaded module regularly {} {}".format(n, m))
                     m.comfy_patched_weights = True
 
         for x in load_completely:
             x[2].to(device_to)
 
         if lowvram_counter > 0:
-            logging.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self._memory_measurements.model_lowvram = True
         else:
-            logging.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self._memory_measurements.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
@@ -574,7 +578,7 @@ class ModelPatcher(ModelManageable):
                 m.comfy_cast_weights = True
                 m.comfy_patched_weights = False
                 memory_freed += module_mem
-                logging.debug("freed {}".format(n))
+                logger.debug("freed {}".format(n))
 
         self._memory_measurements.model_lowvram = True
         self._memory_measurements.lowvram_patch_counter += patch_counter
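The NamedTuples introduced here are positional drop-in replacements for the bare tuples they annotate, which is why calculate_weight's p[0]..p[4] indexing keeps working. A minimal sketch of how the pieces line up, assuming the fork's comfy package is importable; the names patch_value, entry, and head are illustrative, not part of the patch:

# Illustrative sketch only, not part of the diff above.
import torch

from comfy.lora_types import PatchTuple, PatchWeightTuple

# A PatchDictValue is (patch_type, tensors), as produced by load_lora();
# "diff" patches carry a single tensor.
patch_value = ("diff", (torch.zeros(4, 4),))

# add_patches() now stores one of these per patch instead of a plain 5-tuple.
entry = PatchTuple(
    strength_patch=1.0,   # multiplier applied to the patch
    patch=patch_value,    # (patch_type, tensors)
    strength_model=1.0,   # multiplier applied to the existing weight
    offset=None,          # add_patches passes None when the key is a plain str
    function=None,        # optional post-processing of the computed weight
)

# Positional access is unchanged, so existing tuple-style consumers still work.
assert entry[0] == entry.strength_patch
assert entry[3] is entry.offset

# get_key_patches() prepends the current weight as a PatchWeightTuple.
head = PatchWeightTuple(weight=torch.ones(4, 4), convert_func=lambda a, **kwargs: a)
assert head[0] is head.weight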