Improve logging and typing information for LoRA patches in ComfyUI

doctorpangloss 2024-11-04 09:38:13 -08:00
parent 021d0d4f57
commit cde95eb71d
3 changed files with 70 additions and 29 deletions


@@ -24,6 +24,7 @@ import torch
 from . import model_base
 from . import model_management
 from . import utils
+from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue
 
 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@@ -35,8 +36,8 @@ LORA_CLIP_MAP = {
 }
 
 
-def load_lora(lora, to_load):
-    patch_dict = {}
+def load_lora(lora, to_load) -> PatchDict:
+    patch_dict: PatchDict = {}
     loaded_keys = set()
     for x in to_load:
         alpha_name = "{}.alpha".format(x)
@@ -197,11 +198,13 @@ def load_lora(lora, to_load):
     return patch_dict
 
 
-def model_lora_keys_clip(model, key_map={}):
+def model_lora_keys_clip(model, key_map=None):
+    if key_map is None:
+        key_map = {}
     sdk = model.state_dict().keys()
     for k in sdk:
         if k.endswith(".weight"):
-            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k # generic lora format without any weird key names
 
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
@@ -253,7 +256,7 @@ def model_lora_keys_clip(model, key_map={}):
             if clip_l_present:
                 t5_index += 1
                 if t5_index == 2:
-                    key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
+                    key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k # OneTrainer Flux
                     t5_index += 1
 
             key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
@@ -275,7 +278,9 @@ def model_lora_keys_clip(model, key_map={}):
     return key_map
 
 
-def model_lora_keys_unet(model, key_map={}):
+def model_lora_keys_unet(model, key_map=None):
+    if key_map is None:
+        key_map = {}
     sd = model.state_dict()
     sdk = sd.keys()
 
@@ -292,7 +297,7 @@ def model_lora_keys_unet(model, key_map={}):
             unet_key = "diffusion_model.{}".format(diffusers_keys[k])
             key_lora = k[:-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = unet_key
-            key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
+            key_map["lycoris_{}".format(key_lora)] = unet_key # simpletuner lycoris format
 
             diffusers_lora_prefix = ["", "unet."]
             for p in diffusers_lora_prefix:
@@ -315,10 +320,9 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora
                 key_map[key_lora] = to
 
-                key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
+                key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) # simpletuner lycoris format
                 key_map[key_lora] = to
 
-
     if isinstance(model, model_base.AuraFlow): # Diffusers lora AuraFlow
         diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
@@ -340,7 +344,7 @@ def model_lora_keys_unet(model, key_map={}):
                 to = diffusers_keys[k]
                 key_map["transformer.{}".format(k[:-len(".weight")])] = to # simpletrainer and probably regular diffusers flux lora format
                 key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # simpletrainer lycoris
-                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+                key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # onetrainer
 
 
     return key_map
@@ -400,13 +404,13 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
     return padded_tensor
 
 
-def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
+def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_dtype=torch.float32):
     for p in patches:
         strength = p[0]
         v = p[1]
         strength_model = p[2]
-        offset = p[3]
-        function = p[4]
+        offset: PatchOffset = p[3]
+        function: PatchConversionFunction = p[4]
         if function is None:
             function = lambda a: a
 
@@ -419,9 +423,9 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             weight *= strength_model
 
         if isinstance(v, list):
-            v = (calculate_weight(v[1:], v[0][1](model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
+            v = (calculate_weight(v[1:], v[0][1](model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype),)
 
-        patch_type = ""
+        patch_type: PatchType = ""
        if len(v) == 1:
             patch_type = "diff"
         elif len(v) == 2:
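Aside: the PatchType annotation documents the dispatch that follows in calculate_weight, where a one-element value is a plain "diff" and a two-element value carries its type name in the first slot. A minimal standalone sketch of that convention; dispatch_patch is a hypothetical helper for illustration, not part of this commit:

# Sketch only: deriving a PatchType value from the shape of a patch value,
# mirroring the len(v) checks in calculate_weight above. dispatch_patch is a
# hypothetical helper, not part of this commit.
from comfy.lora_types import PatchType


def dispatch_patch(v: tuple) -> PatchType:
    if len(v) == 1:
        return "diff"   # bare weight difference
    if len(v) == 2:
        return v[0]     # two-element form: ("lora" | "loha" | "lokr" | "glora", payload)
    return ""           # unrecognized shape; calculate_weight warns on unknown patch types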
@@ -574,7 +578,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
 
             try:
                 if old_glora:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) # old lycoris glora
                 else:
                     if weight.dim() > 2:
                         lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
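The key_map={} to key_map=None change in model_lora_keys_clip and model_lora_keys_unet avoids Python's shared-mutable-default pitfall, where a default dict created once at definition time is reused and mutated across calls. A minimal, ComfyUI-independent sketch of the difference:

# Sketch: why a mutable default argument is replaced with None plus an explicit check.
def collect_bad(item, seen=[]):        # one list object shared by every call
    seen.append(item)
    return seen


def collect_good(item, seen=None):     # fresh list per call unless one is passed in
    if seen is None:
        seen = []
    seen.append(item)
    return seen


assert collect_bad("a") == ["a"]
assert collect_bad("b") == ["a", "b"]      # state leaked from the previous call
assert collect_good("a") == ["a"]
assert collect_good("b") == ["b"]          # no leakage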

comfy/lora_types.py (new file, 33 additions)

@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from typing import Literal, Any, NamedTuple, Protocol, Callable
+
+import torch
+
+PatchOffset = tuple[int, int, int]
+PatchFunction = Any
+PatchDictKey = str | tuple[str, PatchOffset] | tuple[str, PatchOffset, PatchFunction]
+PatchType = Literal["lora", "loha", "lokr", "glora", "diff", ""]
+PatchDictValue = tuple[PatchType, tuple]
+PatchDict = dict[PatchDictKey, PatchDictValue]
+
+
+class PatchConversionFunction(Protocol):
+    def __call__(self, tensor: torch.Tensor, **kwargs) -> torch.Tensor:
+        ...
+
+
+class PatchWeightTuple(NamedTuple):
+    weight: torch.Tensor
+    convert_func: PatchConversionFunction | Callable[[torch.Tensor], torch.Tensor]
+
+
+class PatchTuple(NamedTuple):
+    strength_patch: float
+    patch: PatchDictValue
+    strength_model: float
+    offset: PatchOffset
+    function: PatchFunction
+
+
+ModelPatchesDictValue = list[PatchTuple | PatchWeightTuple]
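Assuming the field names above, downstream code can read patch entries by name rather than by position. An illustrative sketch with a made-up weight key and placeholder tensors:

# Sketch: constructing and reading the typed patch containers by field name.
# The weight key and tensors are placeholders, not real model state.
import torch

from comfy.lora_types import PatchDict, PatchTuple, PatchWeightTuple

patches: PatchDict = {
    "diffusion_model.example_block.weight": ("diff", (torch.zeros(4, 4),)),
}

entry = PatchTuple(
    strength_patch=1.0,
    patch=patches["diffusion_model.example_block.weight"],
    strength_model=1.0,
    offset=None,     # add_patches passes None when no offset or function is supplied
    function=None,
)
print(entry.strength_patch, entry.patch[0])    # field access instead of entry[0], entry[1][0]

base = PatchWeightTuple(weight=torch.zeros(4, 4), convert_func=lambda a, **kwargs: a)
print(base.convert_func(base.weight).shape)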


@@ -30,9 +30,12 @@ from . import model_management, lora
 from . import utils
 from .comfy_types import UnetWrapperFunction
 from .float import stochastic_rounding
+from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
 from .model_base import BaseModel
 from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions
 
+logger = logging.getLogger(__name__)
+
 
 def string_to_seed(data):
     crc = 0xFFFFFFFF
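Replacing root-level logging.* calls with a module logger lets an embedding application control this module's verbosity independently of everything else. A small sketch of host-side configuration; the module path comfy.model_patcher is assumed from the class defined in this file, and the format string is illustrative:

# Sketch: a host application turning on DEBUG output for the patcher only,
# assuming this module is importable as comfy.model_patcher.
import logging

logging.basicConfig(level=logging.WARNING, format="%(name)s %(levelname)s: %(message)s")
logging.getLogger("comfy.model_patcher").setLevel(logging.DEBUG)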
@@ -134,7 +137,7 @@ class ModelPatcher(ModelManageable):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model: BaseModel | torch.nn.Module = model
-        self.patches = {}
+        self.patches: dict[PatchDictKey, ModelPatchesDictValue] = {}
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
@@ -143,7 +146,7 @@ class ModelPatcher(ModelManageable):
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
-        self.patches_uuid = uuid.uuid4()
+        self.patches_uuid: uuid.UUID = uuid.uuid4()
         self.ckpt_name = ckpt_name
 
         self._memory_measurements = MemoryMeasurements(self.model)
@@ -202,7 +205,7 @@ class ModelPatcher(ModelManageable):
 
         if self.patches_uuid == clone.patches_uuid:
             if len(self.patches) != len(clone.patches):
-                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+                logger.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
             else:
                 return True
 
@@ -316,14 +319,15 @@ class ModelPatcher(ModelManageable):
         if hasattr(self.model, "get_dtype"):
             return self.model.get_dtype()
 
-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
+    def add_patches(self, patches: PatchDict, strength_patch=1.0, strength_model=1.0) -> list[PatchDictKey]:
+        p: set[PatchDictKey] = set()
         model_sd = self.model.state_dict()
+        k: PatchDictKey
         for k in patches:
             offset = None
             function = None
             if isinstance(k, str):
-                key = k
+                key: str = k
             else:
                 offset = k[1]
                 key = k[0]
@@ -333,7 +337,7 @@ class ModelPatcher(ModelManageable):
             if key in model_sd:
                 p.add(k)
                 current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
+                current_patches.append(PatchTuple(strength_patch, patches[k], strength_model, offset, function))
                 self.patches[key] = current_patches
 
         self.patches_uuid = uuid.uuid4()
@@ -354,9 +358,9 @@ class ModelPatcher(ModelManageable):
                 convert_func = lambda a, **kwargs: a
             if k in self.patches:
-                p[k] = [(weight, convert_func)] + self.patches[k]
+                p[k] = [PatchWeightTuple(weight, convert_func)] + self.patches[k]
             else:
-                p[k] = [(weight, convert_func)]
+                p[k] = [PatchWeightTuple(weight, convert_func)]
 
         return p
 
     def model_state_dict(self, filter_prefix=None):
@@ -460,17 +464,17 @@ class ModelPatcher(ModelManageable):
                 self.patch_weight_to_device(weight_key, device_to=device_to)
                 self.patch_weight_to_device(bias_key, device_to=device_to)
-                logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+                logger.debug("lowvram: loaded module regularly {} {}".format(n, m))
                 m.comfy_patched_weights = True
 
         for x in load_completely:
             x[2].to(device_to)
 
         if lowvram_counter > 0:
-            logging.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self._memory_measurements.model_lowvram = True
         else:
-            logging.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self._memory_measurements.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
 
@@ -574,7 +578,7 @@ class ModelPatcher(ModelManageable):
                 m.comfy_cast_weights = True
                 m.comfy_patched_weights = False
                 memory_freed += module_mem
-                logging.debug("freed {}".format(n))
+                logger.debug("freed {}".format(n))
 
         self._memory_measurements.model_lowvram = True
         self._memory_measurements.lowvram_patch_counter += patch_counter
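Taken together, the annotations let the LoRA path read as typed data flowing from load_lora into ModelPatcher.add_patches. A hedged end-to-end sketch; apply_lora_patches is a hypothetical helper, the comfy.model_patcher import path is assumed from the class shown above, and ComfyUI's real entry point for this flow remains comfy.sd.load_lora_for_models:

# Sketch: the typed flow from a LoRA state dict to model patches, assuming an
# already-constructed ModelPatcher for the diffusion model. apply_lora_patches
# is a hypothetical helper for illustration only.
from __future__ import annotations

from comfy import lora
from comfy.lora_types import PatchDict, PatchDictKey
from comfy.model_patcher import ModelPatcher


def apply_lora_patches(patcher: ModelPatcher, lora_sd: dict, strength: float = 1.0) -> list[PatchDictKey]:
    key_map = lora.model_lora_keys_unet(patcher.model)      # LoRA key names -> model weight keys
    patches: PatchDict = lora.load_lora(lora_sd, key_map)   # typed patch dict from this commit
    return patcher.add_patches(patches, strength_patch=strength)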