Improve logging and tracing for validation errors

This commit is contained in:
doctorpangloss 2024-07-16 12:26:30 -07:00
parent a20bf8134d
commit 72baecad87
12 changed files with 167 additions and 79 deletions

View File

@ -1,13 +1,15 @@
import glob
import json
import os
import re
import uuid
import glob
import shutil
import uuid
from aiohttp import web
from .app_settings import AppSettings
from ..cli_args import args
from ..cmd.folder_paths import user_directory
from .app_settings import AppSettings
class UserManager():
@ -17,9 +19,6 @@ class UserManager():
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(self.users_file):
@ -129,7 +128,7 @@ class UserManager():
return web.json_response(results)
def get_user_data_path(request, check_exists = False, param = "file"):
def get_user_data_path(request, check_exists=False, param="file"):
file = request.match_info.get(param, None)
if not file:
return web.Response(status=400)

View File

@ -8,22 +8,22 @@ import sys
import threading
import traceback
import typing
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple
import lazy_object_proxy
import torch
from opentelemetry.trace import get_current_span, StatusCode, Status
from typing_extensions import TypedDict
from .main_pre import tracer
from .. import interruption
from .. import model_management
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import new_execution_context, ExecutionContext
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
# various functions that are declared here. It should have been a context in the first place.
@ -492,7 +492,7 @@ class PromptExecutor:
model_management.unload_all_models()
def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], typing.Any]:
def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
# todo: this should check if LoadImage / LoadImageMask paths exist
# todo: or, nodes should provide a way to validate their values
unique_id = item
@ -506,11 +506,12 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required']
error: ValidationErrorDict
errors = []
valid = True
# todo: investigate if these are at the right indent level
info = None
info: Optional[InputTypeSpec] = None
val = None
validate_function_inputs = []
@ -531,7 +532,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
continue
val = inputs[x]
info = required_inputs[x]
info: InputTypeSpec = required_inputs[x]
type_input = info[0]
if isinstance(val, list):
if len(val) != 2:
@ -593,7 +594,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
"linked_node": val
}
}]
validated[o_id] = (False, reasons, o_id)
validated[o_id] = ValidateInputsTuple(False, reasons, o_id)
continue
else:
try:
@ -622,10 +623,11 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
continue
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
has_min_max: IntSpecOptions | FloatSpecOptions = info[1]
if "min" in has_min_max and val < has_min_max["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"message": "Value {} smaller than min of {}".format(val, has_min_max["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
@ -635,10 +637,10 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
}
errors.append(error)
continue
if "max" in info[1] and val > info[1]["max"]:
if "max" in has_min_max and val > has_min_max["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"message": "Value {} bigger than max of {}".format(val, has_min_max["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
@ -706,9 +708,9 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
ret = ValidateInputsTuple(False, errors, unique_id)
else:
ret = (True, [], unique_id)
ret = ValidateInputsTuple(True, [], unique_id)
validated[unique_id] = ret
return ret
@ -721,23 +723,25 @@ def full_type_name(klass):
return module + '.' + klass.__qualname__
class ValidationErrorExtraInfoDict(TypedDict):
exception_type: str
traceback: List[str]
class ValidationErrorDict(TypedDict):
type: str
message: str
details: str
extra_info: ValidationErrorExtraInfoDict | dict
ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]]
@tracer.start_as_current_span("Validate Prompt")
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
res = _validate_prompt(prompt)
if not res.valid:
span = get_current_span()
span.set_status(Status(StatusCode.ERROR))
if res.error is not None and len(res.error) > 0:
span.set_attributes({
f"error.{k}": v for k, v in res.error.items()
})
if len(res.node_errors) > 0:
for node_id, node_error in res.node_errors.items():
for node_error_field, node_error_value in node_error.items():
if isinstance(node_error_value, (str, bool, int, float)):
span.set_attribute("node_errors.{node_id}.{node_error_field}", node_error_value)
return res
def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
@ -747,7 +751,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
return ValidationTuple(False, error, [], [])
class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
@ -758,7 +762,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
return ValidationTuple(False, error, [], [])
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
@ -770,15 +774,15 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
"details": "",
"extra_info": {}
}
return False, error, [], []
return ValidationTuple(False, error, [], [])
good_outputs = set()
errors = []
node_errors = {}
validated = {}
node_errors: typing.Dict[str, NodeErrorsDictValue] = {}
validated: typing.Dict[str, ValidateInputsTuple] = {}
for o in outputs:
valid = False
reasons = []
reasons: List[ValidationErrorDict] = []
try:
m = validate_inputs(prompt, o, validated)
valid = m[0]
@ -796,7 +800,7 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
"traceback": traceback.format_tb(tb)
}
}]
validated[o] = (False, reasons, o)
validated[o] = ValidateInputsTuple(False, reasons, o)
if valid is True:
good_outputs.add(o)
@ -841,9 +845,9 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
"extra_info": {}
}
return False, error, list(good_outputs), node_errors
return ValidationTuple(False, error, list(good_outputs), node_errors)
return True, None, list(good_outputs), node_errors
return ValidationTuple(True, None, list(good_outputs), node_errors)
class PromptQueue(AbstractPromptQueue):

View File

@ -7,9 +7,9 @@ import shutil
import threading
import time
from .extra_model_paths import load_extra_path_config
# main_pre must be the earliest import since it suppresses some spurious warnings
from .main_pre import args
from .extra_model_paths import load_extra_path_config
from .. import model_management
from ..analytics.analytics import initialize_event_tracking
from ..cmd import cuda_malloc
@ -223,7 +223,10 @@ async def main():
def entrypoint():
asyncio.run(main())
try:
asyncio.run(main())
except KeyboardInterrupt as keyboard_interrupt:
logging.info(f"Gracefully shutting down due to {keyboard_interrupt}")
if __name__ == "__main__":

View File

@ -492,9 +492,6 @@ class PromptServer(ExecutorToClientProgress):
@routes.post("/prompt")
async def post_prompt(request):
logging.info("got prompt")
resp_code = 200
out_string = ""
json_data = await request.json()
json_data = self.trigger_on_prompt(json_data)

View File

@ -1,11 +1,13 @@
from __future__ import annotations # for Python 3.7-3.9
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
import typing
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple, List
import PIL.Image
from typing_extensions import NotRequired, TypedDict
from .queue_types import BinaryEventTypes
from ..nodes.package_typing import InputTypeSpec
class ExecInfo(TypedDict):
@ -85,3 +87,46 @@ class ExecutorToClientProgress(Protocol):
:return:
"""
pass
ExceptionTypes = Literal["custom_validation_failed", "value_not_in_list", "value_bigger_than_max", "value_smaller_than_min", "invalid_input_type", "exception_during_inner_validation", "return_type_mismatch", "bad_linked_input", "required_input_missing", "invalid_prompt", "prompt_no_outputs", "exception_during_validation", "prompt_outputs_failed_validation"]
class ValidationErrorExtraInfoDict(TypedDict, total=False):
exception_type: NotRequired[str]
traceback: NotRequired[List[str]]
dependent_outputs: NotRequired[List[str]]
class_type: NotRequired[str]
input_name: NotRequired[str]
input_config: NotRequired[typing.Dict[str, InputTypeSpec]]
received_value: NotRequired[typing.Any]
linked_node: NotRequired[str]
traceback: NotRequired[str]
exception_message: NotRequired[str]
exception_type: NotRequired[str]
class ValidationErrorDict(TypedDict):
type: str
message: str
details: str
extra_info: list[typing.Never] | ValidationErrorExtraInfoDict
class NodeErrorsDictValue(TypedDict, total=False):
dependent_outputs: NotRequired[List[str]]
errors: List[ValidationErrorDict]
class_type: str
class ValidationTuple(typing.NamedTuple):
valid: bool
error: Optional[ValidationErrorDict]
good_output_node_ids: List[str]
node_errors: list[typing.Never] | typing.Dict[str, NodeErrorsDictValue]
class ValidateInputsTuple(typing.NamedTuple):
valid: bool
errors: List[ValidationErrorDict]
unique_id: str

View File

@ -89,8 +89,6 @@ class BaseModel(torch.nn.Module):
self.adm_channels = 0
self.concat_keys = ()
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t

View File

@ -315,6 +315,8 @@ KNOWN_CONTROLNETS: Final[List[Downloadable]] = [
HuggingFile("xinsir/controlnet-openpose-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-openpose-sdxl-1.0.safetensors"),
HuggingFile("xinsir/anime-painter", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-anime-painter-scribble-sdxl-1.0.safetensors"),
HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
]
KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [

View File

@ -10,9 +10,11 @@ from typing import Literal, List
import psutil
import torch
from opentelemetry.trace import get_current_span
from . import interruption
from .cli_args import args
from .cmd.main_pre import tracer
from .model_management_types import ModelManageable
model_management_lock = RLock()
@ -356,6 +358,12 @@ class LoadedModel:
def __eq__(self, other):
return self.model is other.model
def __str__(self):
if self.model is not None:
return f"<LoadedModel {str(self.model)}>"
else:
return f"<LoadedModel>"
def minimum_inference_memory():
return (1024 * 1024 * 1024)
@ -392,9 +400,12 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> b
return unload_weight
def free_memory(memory_required, device, keep_loaded=[]):
@tracer.start_as_current_span("Free Memory")
def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
span = get_current_span()
span.set_attribute("memory_required", memory_required)
with model_management_lock:
unloaded_model = []
unloaded_models: List[LoadedModel] = []
can_unload = []
for i in range(len(current_loaded_models) - 1, -1, -1):
@ -410,12 +421,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
if get_free_memory(device) > memory_required:
break
current_loaded_models[i].model_unload()
unloaded_model.append(i)
unloaded_models.append(i)
for i in sorted(unloaded_model, reverse=True):
for i in sorted(unloaded_models, reverse=True):
current_loaded_models.pop(i)
if len(unloaded_model) > 0:
if len(unloaded_models) > 0:
soft_empty_cache()
else:
if vram_state != VRAMState.HIGH_VRAM:
@ -423,10 +434,16 @@ def free_memory(memory_required, device, keep_loaded=[]):
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
span.set_attribute("unloaded_models", list(map(str, unloaded_models)))
return unloaded_models
@tracer.start_as_current_span("Load Models GPU")
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
global vram_state
span = get_current_span()
if memory_required != 0:
span.set_attribute("memory_required", memory_required)
with model_management_lock:
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
@ -452,19 +469,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
loaded.currently_used = True
models_already_loaded.append(loaded)
if loaded is None:
if hasattr(x, "model"):
logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
models_freed: List[LoadedModel] = []
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded)
models_freed += free_memory(extra_mem, d, models_already_loaded)
return
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {}
for loaded_model in models_to_load:
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
@ -472,7 +486,12 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
span.set_attribute("models_to_load", list(map(str, models_to_load)))
span.set_attribute("models_freed", list(map(str, models_freed)))
logging.info(f"Models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded

View File

@ -2,13 +2,14 @@ import copy
import inspect
import logging
import uuid
from typing import Optional
import torch
from . import model_management
from . import utils
from .types import UnetWrapperFunction
from .model_management_types import ModelManageable
from .types import UnetWrapperFunction
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
@ -60,14 +61,16 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl
model_options["disable_cfg1_optimization"] = True
return model_options
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
if disable_cfg1_optimization:
model_options["disable_cfg1_optimization"] = True
return model_options
class ModelPatcher(ModelManageable):
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False, ckpt_name: Optional[str] = None):
self.size = size
self.model = model
self.patches = {}
@ -87,6 +90,7 @@ class ModelPatcher(ModelManageable):
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.patches_uuid = uuid.uuid4()
self.ckpt_name = ckpt_name
self._lowvram_patch_counter = 0
@property
@ -105,6 +109,7 @@ class ModelPatcher(ModelManageable):
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
n.ckpt_name = self.ckpt_name
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@ -578,3 +583,9 @@ class ModelPatcher(ModelManageable):
@property
def current_device(self) -> torch.device:
return self._current_device
def __str__(self):
if self.ckpt_name is not None:
return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__})>"
else:
return f"<ModelPatcher for {self.model.__class__.__name__}>"

View File

@ -1,10 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing_extensions import TypedDict, NotRequired, Generic
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable, List, Type
from typing_extensions import TypedDict, NotRequired
T = TypeVar('T')
@ -51,7 +52,7 @@ StringSpec = Tuple[Literal["STRING"], StringSpecOptions]
BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]]
ChoiceSpec = Tuple[Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]]
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
@ -73,6 +74,7 @@ ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
IsChangedMethod = Callable[[Type[Any], ...], str]
class FunctionReturnsUIVariables(TypedDict):
ui: dict
result: NotRequired[Sequence[Any]]
@ -123,6 +125,10 @@ class CustomNode(Protocol):
IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
@classmethod
def __call__(cls, *args, **kwargs) -> 'CustomNode':
...
@dataclass
class ExportedNodes:

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import dataclasses
import logging
import os.path
from enum import Enum
from typing import Any, Optional
@ -16,18 +17,18 @@ from . import model_detection
from . import model_management
from . import model_patcher
from . import model_sampling
from . import sa_t5
from . import sd1_clip
from . import sd2_clip
from . import sd3_clip
from . import sdxl_clip
from . import utils
from .ldm.audio.autoencoder import AudioOobleckVAE
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.audio.autoencoder import AudioOobleckVAE
from .t2i_adapter import adapter
from .taesd import taesd
from . import sd3_clip
from . import sa_t5
from .text_encoders import aura_t5
@ -228,7 +229,7 @@ class VAE:
self.latent_channels = 64
self.output_channels = 2
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.downscale_ratio = 2048
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@ -302,7 +303,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
def decode(self, samples_in):
try:
@ -558,9 +559,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
logging.debug("left over keys: {}".format(left_over))
if output_model:
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU")
model_management.load_model_gpu(_model_patcher)
return (_model_patcher, clip, vae, clipvision)
@ -568,7 +568,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): # load unet in diffusers or regular format
#Allow loading unets from checkpoint files
# Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0:
@ -583,7 +583,7 @@ def load_unet_state_dict(sd): # load unet in diffusers or regular format
new_sd = sd
else:
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
if new_sd is not None: #diffusers mmdit
if new_sd is not None: # diffusers mmdit
model_config = model_detection.model_config_from_unet(new_sd, "")
if model_config is None:
return None

View File

@ -46,6 +46,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
@pytest.mark.asyncio
def test_known_repos(tmp_path_factory):
prev_hub_cache = os.getenv("HF_HUB_CACHE")
os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
from comfy.cmd import folder_paths
from comfy.cmd.folder_paths import FolderPathsTuple
from comfy.model_downloader import get_huggingface_repo_list, \
@ -57,9 +60,10 @@ def test_known_repos(tmp_path_factory):
test_repo_id = "doctorpangloss/comfyui_downloader_test"
prev_huggingface = folder_paths.folder_names_and_paths["huggingface"]
prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"]
prev_hub_cache = os.getenv("HF_HUB_CACHE")
_delete_repo_from_huggingface_cache(test_repo_id)
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
args.disable_known_models = False
try:
folder_paths.folder_names_and_paths["huggingface"] += FolderPathsTuple("huggingface", [test_local_dir], {""})
folder_paths.folder_names_and_paths["huggingface_cache"] += FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})