Merge branch 'comfyanonymous:master' into master

commit 21d9bef158
Author: patientx
Date: 2024-10-09 11:22:05 +03:00 (committed by GitHub)
2 changed files with 12 additions and 8 deletions

comfy/lora.py

@@ -343,10 +343,10 @@ def model_lora_keys_unet(model, key_map={}):
     return key_map


-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
     dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
     lora_diff *= alpha
-    weight_calc = weight + lora_diff.type(weight.dtype)
+    weight_calc = weight + function(lora_diff).type(weight.dtype)
     weight_norm = (
         weight_calc.transpose(0, 1)
         .reshape(weight_calc.shape[1], -1)
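For context, here is a minimal standalone sketch of the DoRA merge this function performs, in plain torch. The steps after the visible hunk (the norm and rescale) are an assumption based on the standard DoRA formulation, and weight_decompose_sketch is a hypothetical name, not the repository's code; the real function also casts dora_scale via comfy.model_management.

import torch

def weight_decompose_sketch(dora_scale, weight, lora_diff, alpha, strength,
                            intermediate_dtype, function):
    # Move the learned per-column magnitudes next to the weight.
    dora_scale = dora_scale.to(device=weight.device, dtype=intermediate_dtype)
    lora_diff = lora_diff * alpha
    # `function` pre-processes the diff (e.g. a dtype/device cast) before it
    # is merged -- this is the parameter this commit adds.
    weight_calc = weight + function(lora_diff).type(weight.dtype)
    # Per-column L2 norm of the merged weight (assumed continuation of the
    # truncated expression in the hunk above).
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )
    # Rescale each column to the learned DoRA magnitude, then blend with the
    # original weight by `strength`.
    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    return weight + strength * (weight_calc - weight)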
@@ -453,7 +453,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
            try:
                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
@@ -499,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
@@ -536,7 +536,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
@@ -577,7 +577,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                lora_diff += torch.mm(b1, b2).reshape(weight.shape)
                if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
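All four calculate_weight hunks make the same substitution: instead of wrapping the decomposed result in function(...) after the fact, function is now passed into weight_decompose so the cast is applied to lora_diff before the weight norm is computed. A condensed sketch of the shared pattern, using the hypothetical weight_decompose_sketch above; apply_patch is an illustrative helper, not a function in the repository.

import torch

def apply_patch(weight, mat1, mat2, alpha, strength, dora_scale=None,
                function=lambda x: x, intermediate_dtype=torch.float32):
    # Build the low-rank diff (the torch.mm variant; the other hunks use
    # torch.kron, elementwise products, or summed matmuls instead).
    lora_diff = torch.mm(mat1.flatten(start_dim=1),
                         mat2.flatten(start_dim=1)).reshape(weight.shape)
    if dora_scale is not None:
        # New behavior: forward `function` into the decomposition.
        weight = weight_decompose_sketch(dora_scale, weight, lora_diff, alpha,
                                         strength, intermediate_dtype, function)
    else:
        # Unchanged: plain LoRA merge scaled by strength * alpha.
        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
    return weight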

folder_paths.py

@@ -234,8 +234,12 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
    for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
        subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
        for file_name in filenames:
-            relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
-            result.append(relative_path)
+            try:
+                relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
+                result.append(relative_path)
+            except:
+                logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
+                continue

        for d in subdirs:
            path: str = os.path.join(dirpath, d)
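The intent of this change is to keep the directory scan alive when a single entry's path cannot be resolved (os.path.relpath can raise, for example ValueError on Windows when the two paths sit on different drives). A self-contained sketch of the patched loop, assuming a simplified stand-in for recursive_search that only collects file paths; the real function also tracks per-directory data.

import logging
import os

def recursive_search_sketch(directory: str,
                            excluded_dir_names: list[str] | None = None) -> list[str]:
    excluded_dir_names = excluded_dir_names or []
    result = []
    for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
        # Prune excluded directories in place so os.walk never descends into them.
        subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
        for file_name in filenames:
            try:
                result.append(os.path.relpath(os.path.join(dirpath, file_name), directory))
            except Exception:
                # Mirrors the patch: log and skip the entry instead of letting
                # one bad path abort the whole scan.
                logging.warning("Unable to access %s. Skipping this file.", file_name)
                continue
    return result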