Mirror of https://github.com/comfyanonymous/ComfyUI.git
Improve logging and typing information for LoRA patches in ComfyUI
parent 021d0d4f57 · commit cde95eb71d
comfy/lora.py

@@ -24,6 +24,7 @@ import torch
 from . import model_base
 from . import model_management
 from . import utils
+from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue

 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@@ -35,8 +36,8 @@ LORA_CLIP_MAP = {
 }


-def load_lora(lora, to_load):
-    patch_dict = {}
+def load_lora(lora, to_load) -> PatchDict:
+    patch_dict: PatchDict = {}
     loaded_keys = set()
     for x in to_load:
         alpha_name = "{}.alpha".format(x)
@@ -197,11 +198,13 @@ def load_lora(lora, to_load):
     return patch_dict


-def model_lora_keys_clip(model, key_map={}):
+def model_lora_keys_clip(model, key_map=None):
+    if key_map is None:
+        key_map = {}
     sdk = model.state_dict().keys()
     for k in sdk:
         if k.endswith(".weight"):
-            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k  # generic lora format without any weird key names

     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
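The key_map=None change above fixes a classic Python pitfall: a mutable default such as key_map={} is evaluated once, at function definition time, so every call that omits the argument shares and mutates the same dict. A minimal sketch of the difference (function names invented for illustration):

    def broken(key_map={}):  # default dict created once, at definition time
        key_map["entry{}".format(len(key_map))] = "k"
        return key_map

    def fixed(key_map=None):  # the pattern this commit switches to
        if key_map is None:
            key_map = {}  # fresh dict on every call
        key_map["entry{}".format(len(key_map))] = "k"
        return key_map

    print(broken())  # {'entry0': 'k'}
    print(broken())  # {'entry0': 'k', 'entry1': 'k'} -- state leaked across calls
    print(fixed())   # {'entry0': 'k'}
    print(fixed())   # {'entry0': 'k'} -- each call starts clean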
@@ -253,7 +256,7 @@ def model_lora_keys_clip(model, key_map={}):
         if clip_l_present:
             t5_index += 1
         if t5_index == 2:
-            key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
+            key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k  # OneTrainer Flux
             t5_index += 1

         key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
@@ -275,7 +278,9 @@ def model_lora_keys_clip(model, key_map={}):
     return key_map


-def model_lora_keys_unet(model, key_map={}):
+def model_lora_keys_unet(model, key_map=None):
+    if key_map is None:
+        key_map = {}
     sd = model.state_dict()
     sdk = sd.keys()

@@ -292,7 +297,7 @@ def model_lora_keys_unet(model, key_map={}):
         unet_key = "diffusion_model.{}".format(diffusers_keys[k])
         key_lora = k[:-len(".weight")].replace(".", "_")
         key_map["lora_unet_{}".format(key_lora)] = unet_key
-        key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
+        key_map["lycoris_{}".format(key_lora)] = unet_key  # simpletuner lycoris format

     diffusers_lora_prefix = ["", "unet."]
     for p in diffusers_lora_prefix:
@@ -315,10 +320,9 @@ def model_lora_keys_unet(model, key_map={}):
         key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))  # OneTrainer lora
         key_map[key_lora] = to

-        key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
+        key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))  # simpletuner lycoris format
         key_map[key_lora] = to

-
     if isinstance(model, model_base.AuraFlow):  # Diffusers lora AuraFlow
         diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
@@ -340,7 +344,7 @@ def model_lora_keys_unet(model, key_map={}):
         to = diffusers_keys[k]
         key_map["transformer.{}".format(k[:-len(".weight")])] = to  # simpletrainer and probably regular diffusers flux lora format
         key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  # simpletrainer lycoris
-        key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+        key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to  # onetrainer

     return key_map

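For orientation: the key_map these functions build translates the various external LoRA naming schemes (kohya-style lora_unet_*, lycoris_*, diffusers transformer.*, OneTrainer lora_transformer_*) into the model's internal state-dict keys. A hypothetical slice of the result, with invented layer names:

    key_map = {
        "lora_unet_input_blocks_1_0": "diffusion_model.input_blocks.1.0.weight",
        "lycoris_input_blocks_1_0": "diffusion_model.input_blocks.1.0.weight",
        "transformer.single_blocks.0.linear1": "diffusion_model.single_blocks.0.linear1.weight",
    }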
@@ -400,13 +404,13 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
     return padded_tensor


-def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
+def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_dtype=torch.float32):
     for p in patches:
         strength = p[0]
         v = p[1]
         strength_model = p[2]
-        offset = p[3]
-        function = p[4]
+        offset: PatchOffset = p[3]
+        function: PatchConversionFunction = p[4]
         if function is None:
             function = lambda a: a

@@ -419,9 +423,9 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             weight *= strength_model

         if isinstance(v, list):
-            v = (calculate_weight(v[1:], v[0][1](model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
+            v = (calculate_weight(v[1:], v[0][1](model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype),)

-        patch_type = ""
+        patch_type: PatchType = ""
         if len(v) == 1:
             patch_type = "diff"
         elif len(v) == 2:
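Annotating patch_type with the PatchType Literal (defined in comfy/lora_types.py below) turns an unknown patch kind into a static type error instead of a silent fallthrough. A small sketch of the effect, with an invented function name:

    from typing import Literal

    PatchType = Literal["lora", "loha", "lokr", "glora", "diff", ""]

    def apply(patch_type: PatchType) -> None:
        ...

    apply("loha")   # accepted
    apply("lorra")  # rejected by mypy/pyright; note there is no runtime check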
@@ -574,7 +578,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):

            try:
                if old_glora:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)  # old lycoris glora
                else:
                    if weight.dim() > 2:
                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
comfy/lora_types.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from typing import Literal, Any, NamedTuple, Protocol, Callable
+
+import torch
+
+PatchOffset = tuple[int, int, int]
+PatchFunction = Any
+PatchDictKey = str | tuple[str, PatchOffset] | tuple[str, PatchOffset, PatchFunction]
+PatchType = Literal["lora", "loha", "lokr", "glora", "diff", ""]
+PatchDictValue = tuple[PatchType, tuple]
+PatchDict = dict[PatchDictKey, PatchDictValue]
+
+
+class PatchConversionFunction(Protocol):
+    def __call__(self, tensor: torch.Tensor, **kwargs) -> torch.Tensor:
+        ...
+
+
+class PatchWeightTuple(NamedTuple):
+    weight: torch.Tensor
+    convert_func: PatchConversionFunction | Callable[[torch.Tensor], torch.Tensor]
+
+
+class PatchTuple(NamedTuple):
+    strength_patch: float
+    patch: PatchDictValue
+    strength_model: float
+    offset: PatchOffset
+    function: PatchFunction
+
+
+ModelPatchesDictValue = list[PatchTuple | PatchWeightTuple]
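A rough usage sketch of the new aliases (the tensors, key names, and payload layout here are invented for illustration; the real payload tuple varies by patch type):

    import torch
    from comfy.lora_types import PatchDict, PatchTuple, PatchWeightTuple

    lora_up = torch.zeros(320, 4)
    lora_down = torch.zeros(4, 320)

    # One entry of the dict returned by load_lora(): key -> (PatchType, payload).
    patch_dict: PatchDict = {
        "diffusion_model.input_blocks.1.0.weight": ("lora", (lora_up, lora_down)),
    }

    # ModelPatcher.add_patches() wraps each entry with its strengths.
    entry = PatchTuple(
        strength_patch=1.0,
        patch=patch_dict["diffusion_model.input_blocks.1.0.weight"],
        strength_model=1.0,
        offset=None,   # no offset slice applied for this patch
        function=None,
    )

    # get_key_patches() prepends the current weight with an identity converter.
    head = PatchWeightTuple(torch.zeros(320, 320), lambda a, **kwargs: a)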
comfy/model_patcher.py

@@ -30,9 +30,12 @@ from . import model_management, lora
 from . import utils
 from .comfy_types import UnetWrapperFunction
 from .float import stochastic_rounding
+from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
 from .model_base import BaseModel
 from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions

+logger = logging.getLogger(__name__)
+

 def string_to_seed(data):
     crc = 0xFFFFFFFF
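Switching from bare logging.debug(...) calls to a module-level logger means each message carries the module's name, so embedding applications can tune this module's verbosity without touching the root logger. A small sketch, assuming the module path is comfy/model_patcher.py and therefore __name__ is "comfy.model_patcher":

    import logging

    logging.basicConfig(level=logging.WARNING)  # keep everything else quiet
    # Opt in to just the patcher's debug output:
    logging.getLogger("comfy.model_patcher").setLevel(logging.DEBUG)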
@@ -134,7 +137,7 @@ class ModelPatcher(ModelManageable):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
         self.model: BaseModel | torch.nn.Module = model
-        self.patches = {}
+        self.patches: dict[PatchDictKey, ModelPatchesDictValue] = {}
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
@@ -143,7 +146,7 @@ class ModelPatcher(ModelManageable):
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
-        self.patches_uuid = uuid.uuid4()
+        self.patches_uuid: uuid.UUID = uuid.uuid4()
         self.ckpt_name = ckpt_name
         self._memory_measurements = MemoryMeasurements(self.model)

@@ -202,7 +205,7 @@ class ModelPatcher(ModelManageable):

         if self.patches_uuid == clone.patches_uuid:
             if len(self.patches) != len(clone.patches):
-                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+                logger.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
             else:
                 return True

@@ -316,14 +319,15 @@ class ModelPatcher(ModelManageable):
         if hasattr(self.model, "get_dtype"):
             return self.model.get_dtype()

-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
+    def add_patches(self, patches: PatchDict, strength_patch=1.0, strength_model=1.0) -> list[PatchDictKey]:
+        p: set[PatchDictKey] = set()
         model_sd = self.model.state_dict()
+        k: PatchDictKey
         for k in patches:
             offset = None
             function = None
             if isinstance(k, str):
-                key = k
+                key: str = k
             else:
                 offset = k[1]
                 key = k[0]
@@ -333,7 +337,7 @@ class ModelPatcher(ModelManageable):
             if key in model_sd:
                 p.add(k)
                 current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
+                current_patches.append(PatchTuple(strength_patch, patches[k], strength_model, offset, function))
                 self.patches[key] = current_patches

         self.patches_uuid = uuid.uuid4()
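PatchTuple is a drop-in replacement for the anonymous tuple previously appended: NamedTuple instances still support positional indexing, so calculate_weight's p[0] … p[4] reads keep working while new code can use field names. A quick check using the class defined in comfy/lora_types.py above:

    from comfy.lora_types import PatchTuple

    pt = PatchTuple(1.0, ("lora", ()), 1.0, None, None)
    assert pt[0] == pt.strength_patch == 1.0
    assert pt[3] is pt.offset is None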
@@ -354,9 +358,9 @@ class ModelPatcher(ModelManageable):
                 convert_func = lambda a, **kwargs: a

             if k in self.patches:
-                p[k] = [(weight, convert_func)] + self.patches[k]
+                p[k] = [PatchWeightTuple(weight, convert_func)] + self.patches[k]
             else:
-                p[k] = [(weight, convert_func)]
+                p[k] = [PatchWeightTuple(weight, convert_func)]
         return p

     def model_state_dict(self, filter_prefix=None):
@@ -460,17 +464,17 @@ class ModelPatcher(ModelManageable):

                 self.patch_weight_to_device(weight_key, device_to=device_to)
                 self.patch_weight_to_device(bias_key, device_to=device_to)
-                logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+                logger.debug("lowvram: loaded module regularly {} {}".format(n, m))
                 m.comfy_patched_weights = True

         for x in load_completely:
             x[2].to(device_to)

         if lowvram_counter > 0:
-            logging.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self._memory_measurements.model_lowvram = True
         else:
-            logging.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self._memory_measurements.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
@@ -574,7 +578,7 @@ class ModelPatcher(ModelManageable):
                     m.comfy_cast_weights = True
                     m.comfy_patched_weights = False
                     memory_freed += module_mem
-                    logging.debug("freed {}".format(n))
+                    logger.debug("freed {}".format(n))

         self._memory_measurements.model_lowvram = True
         self._memory_measurements.lowvram_patch_counter += patch_counter