diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 0fc38a9cf..a5f0cae45 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -87,6 +87,7 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto, help="Default preview method for sampler nodes.", action=EnumAction) + parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") cache_group = parser.add_mutually_exclusive_group() cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") attn_group = parser.add_mutually_exclusive_group() diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 70af63f6f..82ce9a212 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -110,6 +110,7 @@ class Configuration(dict): force_channels_last (bool): Force channels last format when inferencing the models. force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure. executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor + preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512. """ def __init__(self, **kwargs): @@ -184,6 +185,7 @@ class Configuration(dict): self.max_queue_size: int = 65536 self.force_channels_last: bool = False self.force_hf_local_dir_mode = False + self.preview_size: int = 512 # from guill self.cache_lru: int = 0 diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index 659e66fa6..68e9767bb 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -100,6 +100,38 @@ def _cleanup(): class EmbeddedComfyClient: + """ + Embedded client for comfy executing prompts as a library. + + This client manages a single-threaded executor to run long-running or blocking tasks + asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor + in a dedicated thread for executing prompts and handling server-stub communications. + Example usage: + + Asynchronous (non-blocking) usage with async-await: + ``` + # Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your + # workspace. + prompt_dict = { + "1": {"class_type": "KSamplerAdvanced", ...} + ... + } + # Validate your workflow (the prompt) + from comfy.api.components.schema.prompt import Prompt + prompt = Prompt.validate(prompt_dict) + # Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor. + # It does not connect to a remote server. + async def main(): + async with EmbeddedComfyClient() as client: + outputs = await client.queue_prompt(prompt) + print(outputs) + print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI") + if __name__ == "__main__" + asyncio.run(main()) + ``` + + In order to use this in blocking methods, learn more about asyncio online. + """ def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None): self._progress_handler = progress_handler or ServerStub() self._executor = executor or ThreadPoolExecutor(max_workers=max_workers) diff --git a/comfy/cmd/latent_preview.py b/comfy/cmd/latent_preview.py index fa19408bb..8c2f00970 100644 --- a/comfy/cmd/latent_preview.py +++ b/comfy/cmd/latent_preview.py @@ -12,7 +12,7 @@ from .. import model_management from .. import utils import logging -MAX_PREVIEW_RESOLUTION = 512 +MAX_PREVIEW_RESOLUTION = args.preview_size def preview_to_image(latent_image): diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 6b53995cb..47fb8c1ab 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -200,6 +200,7 @@ async def main(from_script_dir: Optional[Path] = None): folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models")) + folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras")) if args.input_directory: input_dir = os.path.abspath(args.input_directory) diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 612a5d7bf..74e6b1bf1 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -6,9 +6,10 @@ from enum import Enum from typing import Optional, Literal, Protocol, Union, NamedTuple, List import PIL.Image -from typing_extensions import NotRequired, TypedDict, runtime_checkable +from typing_extensions import NotRequired, TypedDict from .queue_types import BinaryEventTypes +from ..cli_args_types import Configuration from ..nodes.package_typing import InputTypeSpec @@ -201,4 +202,7 @@ class Executor(Protocol): ... def shutdown(self, wait=True, *, cancel_futures=False): - ... \ No newline at end of file + ... + + +ExecutePromptArgs = tuple[dict, str, str, dict, ExecutorToClientProgress | None, Configuration | None] diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 631b22b62..835c85794 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -493,14 +493,15 @@ def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full return control -def load_controlnet_flux_xlabs(sd): +def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd) - control_model = controlnet_flux.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_flux.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control + def load_controlnet_flux_instantx(sd): new_sd = model_detection.convert_diffusers_mmdit(sd, "") model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd) @@ -520,6 +521,11 @@ def load_controlnet_flux_instantx(sd): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control + +def convert_mistoline(sd): + return utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) + + def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]): controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT @@ -580,13 +586,15 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]): if len(leftover_keys) > 0: logging.warning("leftover keys: {}".format(leftover_keys)) controlnet_data = new_sd - elif "controlnet_blocks.0.weight" in controlnet_data: # SD3 diffusers format + elif "controlnet_blocks.0.weight" in controlnet_data: if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: - return load_controlnet_flux_xlabs(controlnet_data) + return load_controlnet_flux_xlabs_mistoline(controlnet_data) elif "pos_embed_input.proj.weight" in controlnet_data: - return load_controlnet_mmdit(controlnet_data) + return load_controlnet_mmdit(controlnet_data) # SD3 diffusers controlnet elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data) + elif "controlnet_blocks.0.linear.weight" in controlnet_data: # mistoline flux + return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True) pth_key = 'control_model.zero_convs.0.0.weight' pth = False diff --git a/comfy/distributed/process_pool_executor.py b/comfy/distributed/process_pool_executor.py index 6a198f47f..8f0e954e8 100644 --- a/comfy/distributed/process_pool_executor.py +++ b/comfy/distributed/process_pool_executor.py @@ -1,11 +1,16 @@ import concurrent.futures +from typing import Callable -from pebble import ProcessPool +from pebble import ProcessPool, ProcessFuture -from ..component_model.executor_types import Executor +from ..component_model.executor_types import Executor, ExecutePromptArgs class ProcessPoolExecutor(ProcessPool, Executor): + def __init__(self, max_workers: int = 1): + super().__init__(max_workers=1) + + def shutdown(self, wait=True, *, cancel_futures=False): if cancel_futures: raise NotImplementedError("cannot cancel futures in this implementation") @@ -15,5 +20,17 @@ class ProcessPoolExecutor(ProcessPool, Executor): self.stop() return + def schedule(self, function: Callable, + args: list = (), + kwargs: dict = {}, + timeout: float = None) -> ProcessFuture: + try: + args: ExecutePromptArgs + prompt, prompt_id, client_id, span_context, progress_handler, configuration = args + + except ValueError: + pass + super().schedule(function, args, kwargs, timeout) + def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future: - return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None) \ No newline at end of file + return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None) diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index afe36d52f..501d807b1 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -1,32 +1,78 @@ -#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py +# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py +# modified to support different types of flux controlnets import math -from typing_extensions import Never import torch from einops import rearrange, repeat from torch import Tensor, nn +from typing_extensions import Never from .layers import (timestep_embedding) from .model import Flux from .. import common_dit +class MistolineCondDownsamplBlock(nn.Module): + def __init__(self, dtype=None, device=None, operations=None): + super().__init__() + self.encoder = nn.Sequential( + operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) + ) + + def forward(self, x): + return self.encoder(x) + + +class MistolineControlnetBlock(nn.Module): + def __init__(self, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.act = nn.SiLU() + + def forward(self, x): + return self.act(self.linear(x)) + + class ControlNetFlux(Flux): - def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs): + def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs): super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) self.main_model_double = 19 self.main_model_single = 38 + + self.mistoline = mistoline # add ControlNet blocks + if self.mistoline: + control_block = lambda: MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + control_block = lambda: operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + self.controlnet_blocks = nn.ModuleList([]) for _ in range(self.params.depth): - controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) - self.controlnet_blocks.append(controlnet_block) + self.controlnet_blocks.append(control_block()) self.controlnet_single_blocks = nn.ModuleList([]) for _ in range(self.params.depth_single_blocks): - self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)) + self.controlnet_single_blocks.append(control_block()) self.num_union_modes = num_union_modes self.controlnet_mode_embedder = None @@ -37,44 +83,44 @@ class ControlNetFlux(Flux): self.latent_input = latent_input self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) if not self.latent_input: - self.input_hint_block = nn.Sequential( - operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) - ) + if self.mistoline: + self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations) + else: + self.input_hint_block = nn.Sequential( + operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) + ) def forward_orig( - self, - img: Tensor, - img_ids: Tensor, - controlnet_cond: Tensor, - txt: Tensor, - txt_ids: Tensor, - timesteps: Tensor, - y: Tensor, - guidance: Tensor = None, - control_type: Tensor | list[Never] | None = None, + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor = None, + control_type: Tensor | list[Never] | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) - if not self.latent_input: - controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) controlnet_cond = self.pos_embed_input(controlnet_cond) img = img + controlnet_cond @@ -87,7 +133,7 @@ class ControlNetFlux(Flux): if self.controlnet_mode_embedder is not None and len(control_type) > 0: control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1)) txt = torch.cat([control_cond, txt], dim=1) - txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1) + txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1) ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) @@ -104,13 +150,13 @@ class ControlNetFlux(Flux): for i in range(len(self.single_blocks)): img = self.single_blocks[i](img, vec=vec, pe=pe) - controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),) + controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1]:, ...]),) repeat = math.ceil(self.main_model_double / len(controlnet_double)) if self.latent_input: out_input = () for x in controlnet_double: - out_input += (x,) * repeat + out_input += (x,) * repeat else: out_input = (controlnet_double * repeat) @@ -120,7 +166,7 @@ class ControlNetFlux(Flux): out_output = () if self.latent_input: for x in controlnet_single: - out_output += (x,) * repeat + out_output += (x,) * repeat else: out_output = (controlnet_single * repeat) out["output"] = out_output[:self.main_model_single] @@ -130,9 +176,14 @@ class ControlNetFlux(Flux): patch_size = 2 if self.latent_input: hint = common_dit.pad_to_patch_size(hint, (patch_size, patch_size)) - hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + elif self.mistoline: + hint = hint * 2.0 - 1.0 + hint = self.input_cond_block(hint) else: hint = hint * 2.0 - 1.0 + hint = self.input_hint_block(hint) + + hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) bs, c, h, w = x.shape x = common_dit.pad_to_patch_size(x, (patch_size, patch_size)) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index c5d7ed119..50b7ae7bb 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -842,6 +842,11 @@ class UNetModel(nn.Module): t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) + if "emb_patch" in transformer_patches: + patch = transformer_patches["emb_patch"] + for p in patch: + emb = p(emb, self.model_channels, transformer_options) + if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) diff --git a/comfy/model_management.py b/comfy/model_management.py index a09c0476f..9b0c0fd8e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -495,7 +495,7 @@ def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]: shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded: - can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) + can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) shift_model.currently_used = False for x in sorted(can_unload): diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py index 240891057..5b105b35b 100644 --- a/comfy/text_encoders/long_clipl.py +++ b/comfy/text_encoders/long_clipl.py @@ -1,7 +1,6 @@ from comfy import sd1_clip -import os -from comfy.component_model.files import get_path_as_dict +from ..component_model.files import get_path_as_dict class LongClipTokenizer_(sd1_clip.SDTokenizer): diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py new file mode 100644 index 000000000..dcb46f0e0 --- /dev/null +++ b/comfy_extras/nodes_lora_extract.py @@ -0,0 +1,91 @@ +import torch +import comfy.model_management +import comfy.utils +import folder_paths +import os +import logging + +CLAMP_QUANTILE = 0.99 + +def extract_lora(diff, rank): + conv2d = (len(diff.shape) == 4) + kernel_size = None if not conv2d else diff.size()[2:4] + conv2d_3x3 = conv2d and kernel_size != (1, 1) + out_dim, in_dim = diff.size()[0:2] + rank = min(rank, in_dim, out_dim) + + if conv2d: + if conv2d_3x3: + diff = diff.flatten(start_dim=1) + else: + diff = diff.squeeze() + + + U, S, Vh = torch.linalg.svd(diff.float()) + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + if conv2d: + U = U.reshape(out_dim, rank, 1, 1) + Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) + return (U, Vh) + +class LoraSave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}), + "rank": ("INT", {"default": 8, "min": 1, "max": 1024, "step": 1}), + }, + "optional": {"model_diff": ("MODEL",),}, + } + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save(self, filename_prefix, rank, model_diff=None): + if model_diff is None: + return {} + + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + + output_sd = {} + prefix_key = "diffusion_model." + stored = set() + + comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True) + sd = model_diff.model_state_dict(filter_prefix=prefix_key) + + for k in sd: + if k.endswith(".weight"): + weight_diff = sd[k] + if weight_diff.ndim < 2: + continue + try: + out = extract_lora(weight_diff, rank) + output_sd["{}.lora_up.weight".format(k[:-7])] = out[0].contiguous().half().cpu() + output_sd["{}.lora_down.weight".format(k[:-7])] = out[1].contiguous().half().cpu() + except: + logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k)) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None) + return {} + +NODE_CLASS_MAPPINGS = { + "LoraSave": LoraSave +}