mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 23:30:16 +08:00
Move comfy_extras nodes, fix pylint errors
This commit is contained in:
parent
db423f8013
commit
0ba08f273a
@ -64,7 +64,7 @@ def get_comfyui_version():
|
|||||||
comfyui_version = "unknown"
|
comfyui_version = "unknown"
|
||||||
repo_path = os.path.dirname(os.path.realpath(__file__))
|
repo_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
try:
|
try:
|
||||||
import pygit2
|
import pygit2 # pylint: disable=import-error
|
||||||
repo = pygit2.Repository(repo_path)
|
repo = pygit2.Repository(repo_path)
|
||||||
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
|
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@ -529,6 +529,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "glora":
|
elif patch_type == "glora":
|
||||||
|
rank = 0
|
||||||
dora_scale = v[5]
|
dora_scale = v[5]
|
||||||
|
|
||||||
old_glora = False
|
old_glora = False
|
||||||
|
|||||||
@ -1,12 +1,15 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import folder_paths
|
from comfy.cmd import folder_paths
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
|
|
||||||
def extract_lora(diff, rank):
|
def extract_lora(diff, rank):
|
||||||
conv2d = (len(diff.shape) == 4)
|
conv2d = (len(diff.shape) == 4)
|
||||||
kernel_size = None if not conv2d else diff.size()[2:4]
|
kernel_size = None if not conv2d else diff.size()[2:4]
|
||||||
@ -20,7 +23,6 @@ def extract_lora(diff, rank):
|
|||||||
else:
|
else:
|
||||||
diff = diff.squeeze()
|
diff = diff.squeeze()
|
||||||
|
|
||||||
|
|
||||||
U, S, Vh = torch.linalg.svd(diff.float())
|
U, S, Vh = torch.linalg.svd(diff.float())
|
||||||
U = U[:, :rank]
|
U = U[:, :rank]
|
||||||
S = S[:rank]
|
S = S[:rank]
|
||||||
@ -38,6 +40,7 @@ def extract_lora(diff, rank):
|
|||||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||||
return (U, Vh)
|
return (U, Vh)
|
||||||
|
|
||||||
|
|
||||||
class LoraSave:
|
class LoraSave:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
@ -45,10 +48,11 @@ class LoraSave:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
||||||
"rank": ("INT", {"default": 8, "min": 1, "max": 1024, "step": 1}),
|
"rank": ("INT", {"default": 8, "min": 1, "max": 1024, "step": 1}),
|
||||||
},
|
},
|
||||||
"optional": {"model_diff": ("MODEL",),},
|
"optional": {"model_diff": ("MODEL",), },
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
RETURN_TYPES = ()
|
||||||
FUNCTION = "save"
|
FUNCTION = "save"
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
@ -86,6 +90,7 @@ class LoraSave:
|
|||||||
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"LoraSave": LoraSave
|
"LoraSave": LoraSave
|
||||||
}
|
}
|
||||||
@ -1,3 +1,6 @@
|
|||||||
|
import dataclasses
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
from bitsandbytes.nn.modules import Params4bit, QuantState
|
from bitsandbytes.nn.modules import Params4bit, QuantState
|
||||||
@ -8,8 +11,9 @@ except (ImportError, ModuleNotFoundError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
class Params4bit:
|
class Params4bit:
|
||||||
pass
|
data: Any
|
||||||
|
|
||||||
|
|
||||||
class QuantState:
|
class QuantState:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user