mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-21 11:50:16 +08:00
Fix #45 now test 3.10 and 3.12 images. NVIDIA doesn't support 3.11 at all!
This commit is contained in:
parent
f6d3962c77
commit
2a881a768e
12
.github/workflows/test.yml
vendored
12
.github/workflows/test.yml
vendored
@ -11,7 +11,7 @@ on: [ push ]
|
|||||||
jobs:
|
jobs:
|
||||||
build_and_execute_linux:
|
build_and_execute_linux:
|
||||||
environment: "Testing"
|
environment: "Testing"
|
||||||
name: Installation, Unit and Workflow Tests for Linux
|
name: Installation, Unit and Workflow Tests for Linux (${{ matrix.runner.friendly_name }})
|
||||||
runs-on: ${{ matrix.runner.labels }}
|
runs-on: ${{ matrix.runner.labels }}
|
||||||
container: ${{ matrix.runner.container }}
|
container: ${{ matrix.runner.container }}
|
||||||
strategy:
|
strategy:
|
||||||
@ -19,7 +19,17 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
runner:
|
runner:
|
||||||
- labels: [self-hosted, Linux, X64, cuda-3090-24gb]
|
- labels: [self-hosted, Linux, X64, cuda-3090-24gb]
|
||||||
|
friendly_name: "Python 3.12 CUDA 12.9.1 Torch-TensorRT 2.8.0a0"
|
||||||
|
container: "nvcr.io/nvidia/pytorch:25.06-py3"
|
||||||
|
- labels: [self-hosted, Linux, X64, cuda-3090-24gb]
|
||||||
|
friendly_name: "(LTS) Python 3.12 CUDA 12.8.1.012 Torch-TensorRT 2.7.0a0"
|
||||||
container: "nvcr.io/nvidia/pytorch:25.03-py3"
|
container: "nvcr.io/nvidia/pytorch:25.03-py3"
|
||||||
|
- labels: [self-hosted, Linux, X64, cuda-3090-24gb]
|
||||||
|
friendly_name: "Python 3.10 CUDA 12.6.2 Torch-TensorRT 2.5.0a0"
|
||||||
|
container: "nvcr.io/nvidia/pytorch:24.10-py3"
|
||||||
|
- labels: [self-hosted, Linux, X64, cuda-3090-24gb]
|
||||||
|
friendly_name: "Python 3.10 CUDA 12.3.2 Torch-TensorRT 2.2.0a0"
|
||||||
|
container: "nvcr.io/nvidia/pytorch:23.12-py3"
|
||||||
steps:
|
steps:
|
||||||
- run: |
|
- run: |
|
||||||
apt-get update
|
apt-get update
|
||||||
|
|||||||
13
.pylintrc
13
.pylintrc
@ -600,7 +600,6 @@ disable=raw-checker-failed,
|
|||||||
preferred-module,
|
preferred-module,
|
||||||
misplaced-future,
|
misplaced-future,
|
||||||
shadowed-import,
|
shadowed-import,
|
||||||
deprecated-module,
|
|
||||||
missing-timeout,
|
missing-timeout,
|
||||||
useless-with-lock,
|
useless-with-lock,
|
||||||
bare-except,
|
bare-except,
|
||||||
@ -623,11 +622,6 @@ disable=raw-checker-failed,
|
|||||||
unspecified-encoding,
|
unspecified-encoding,
|
||||||
forgotten-debug-statement,
|
forgotten-debug-statement,
|
||||||
method-cache-max-size-none,
|
method-cache-max-size-none,
|
||||||
deprecated-method,
|
|
||||||
deprecated-argument,
|
|
||||||
deprecated-class,
|
|
||||||
deprecated-decorator,
|
|
||||||
deprecated-attribute,
|
|
||||||
bad-format-string-key,
|
bad-format-string-key,
|
||||||
unused-format-string-key,
|
unused-format-string-key,
|
||||||
bad-format-string,
|
bad-format-string,
|
||||||
@ -678,7 +672,12 @@ disable=raw-checker-failed,
|
|||||||
# either give multiple identifier separated by comma (,) or put this option
|
# either give multiple identifier separated by comma (,) or put this option
|
||||||
# multiple time (only on the command line, not in the configuration file where
|
# multiple time (only on the command line, not in the configuration file where
|
||||||
# it should appear only once). See also the "--disable" option for examples.
|
# it should appear only once). See also the "--disable" option for examples.
|
||||||
enable=
|
enable=deprecated-module,
|
||||||
|
deprecated-method,
|
||||||
|
deprecated-argument,
|
||||||
|
deprecated-class,
|
||||||
|
deprecated-decorator,
|
||||||
|
deprecated-attribute
|
||||||
|
|
||||||
[METHOD_ARGS]
|
[METHOD_ARGS]
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
from asyncio import Task, Future
|
from asyncio import Task, Future
|
||||||
from typing import override, NamedTuple, Optional, AsyncIterable
|
from typing import NamedTuple, Optional, AsyncIterable
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from .client_types import V1QueuePromptResponse, ProgressNotification
|
from .client_types import V1QueuePromptResponse, ProgressNotification
|
||||||
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
|
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
|
||||||
|
|||||||
@ -3,11 +3,11 @@ from __future__ import annotations
|
|||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override, TYPE_CHECKING
|
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn
|
import torch.nn
|
||||||
from typing_extensions import TypedDict, NotRequired
|
from typing_extensions import TypedDict, NotRequired, override
|
||||||
|
|
||||||
from .comfy_types import UnetWrapperFunction
|
from .comfy_types import UnetWrapperFunction
|
||||||
from .latent_formats import LatentFormat
|
from .latent_formats import LatentFormat
|
||||||
@ -109,8 +109,6 @@ class HooksSupportStub(HooksSupport, metaclass=ABCMeta):
|
|||||||
if isinstance(model, BaseModel) or hasattr(model, "current_patcher") and isinstance(self, ModelManageable):
|
if isinstance(model, BaseModel) or hasattr(model, "current_patcher") and isinstance(self, ModelManageable):
|
||||||
model.current_patcher = self
|
model.current_patcher = self
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_state(self, *args, **kwargs):
|
def prepare_state(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -219,7 +217,6 @@ class ModelManageableStub(HooksSupportStub, TrainingSupportStub, ModelManageable
|
|||||||
:see: PatchSupport
|
:see: PatchSupport
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True,
|
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True,
|
||||||
force_patch_weights: bool = False) -> torch.nn.Module:
|
force_patch_weights: bool = False) -> torch.nn.Module:
|
||||||
|
|||||||
@ -8,14 +8,8 @@ import os
|
|||||||
import re
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
import zipfile
|
import zipfile
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Tuple, Sequence, TypeVar, Callable, Optional, Union
|
||||||
try:
|
|
||||||
from importlib.resources.abc import Traversable # pylint: disable=no-name-in-module
|
|
||||||
except ImportError:
|
|
||||||
from importlib.abc import Traversable # pylint: disable=no-name-in-module
|
|
||||||
from typing import Tuple, Sequence, TypeVar, Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
|
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
|
||||||
@ -27,6 +21,11 @@ from .component_model import files
|
|||||||
from .component_model.files import get_path_as_dict, get_package_as_path
|
from .component_model.files import get_path_as_dict, get_package_as_path
|
||||||
from .text_encoders.spiece_tokenizer import SPieceTokenizer
|
from .text_encoders.spiece_tokenizer import SPieceTokenizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from importlib.resources.abc import Traversable # pylint: disable=no-name-in-module
|
||||||
|
except ImportError:
|
||||||
|
from importlib.abc import Traversable # pylint: disable=no-name-in-module, deprecated-class
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def gen_empty_tokens(special_tokens, length):
|
def gen_empty_tokens(special_tokens, length):
|
||||||
@ -547,7 +546,7 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
|
|||||||
|
|
||||||
|
|
||||||
class SDTokenizer:
|
class SDTokenizer:
|
||||||
def __init__(self, tokenizer_path: torch.Tensor | bytes | bytearray | memoryview | str | Path | Traversable = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data=None, tokenizer_args=None):
|
def __init__(self, tokenizer_path: Optional[Union[torch.Tensor, bytes, bytearray, memoryview, str, Path, Traversable]] = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data=None, tokenizer_args=None):
|
||||||
if tokenizer_data is None:
|
if tokenizer_data is None:
|
||||||
tokenizer_data = dict()
|
tokenizer_data = dict()
|
||||||
if tokenizer_args is None:
|
if tokenizer_args is None:
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import comfy.model_patcher
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def easycache_forward_wrapper(executor, *args, **kwargs):
|
def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||||
# get values from args
|
# get values from args
|
||||||
@ -32,7 +33,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
|||||||
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
||||||
if easycache.skip_current_step:
|
if easycache.skip_current_step:
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
logger.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
||||||
return easycache.apply_cache_diff(x, uuids)
|
return easycache.apply_cache_diff(x, uuids)
|
||||||
if easycache.initial_step:
|
if easycache.initial_step:
|
||||||
easycache.first_cond_uuid = uuids[0]
|
easycache.first_cond_uuid = uuids[0]
|
||||||
@ -46,13 +47,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
|||||||
easycache.cumulative_change_rate += approx_output_change_rate
|
easycache.cumulative_change_rate += approx_output_change_rate
|
||||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
logger.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||||
# other conds should also skip this step, and instead use their cached values
|
# other conds should also skip this step, and instead use their cached values
|
||||||
easycache.skip_current_step = True
|
easycache.skip_current_step = True
|
||||||
return easycache.apply_cache_diff(x, uuids)
|
return easycache.apply_cache_diff(x, uuids)
|
||||||
else:
|
else:
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
logger.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||||
easycache.cumulative_change_rate = 0.0
|
easycache.cumulative_change_rate = 0.0
|
||||||
|
|
||||||
output: torch.Tensor = executor(*args, **kwargs)
|
output: torch.Tensor = executor(*args, **kwargs)
|
||||||
@ -65,11 +66,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
|||||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||||
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
logger.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
||||||
if input_change is not None:
|
if input_change is not None:
|
||||||
easycache.relative_transformation_rate = output_change / input_change
|
easycache.relative_transformation_rate = output_change / input_change
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
|
logger.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
|
||||||
# TODO: allow cache_diff to be offloaded
|
# TODO: allow cache_diff to be offloaded
|
||||||
easycache.update_cache_diff(output, next_x_prev, uuids)
|
easycache.update_cache_diff(output, next_x_prev, uuids)
|
||||||
if has_first_cond_uuid:
|
if has_first_cond_uuid:
|
||||||
@ -77,7 +78,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
|||||||
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
|
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
|
||||||
easycache.output_prev_norm = output.flatten().abs().mean()
|
easycache.output_prev_norm = output.flatten().abs().mean()
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
logger.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||||
@ -102,13 +103,13 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
|||||||
easycache.cumulative_change_rate += approx_output_change_rate
|
easycache.cumulative_change_rate += approx_output_change_rate
|
||||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
logger.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||||
# other conds should also skip this step, and instead use their cached values
|
# other conds should also skip this step, and instead use their cached values
|
||||||
easycache.skip_current_step = True
|
easycache.skip_current_step = True
|
||||||
return easycache.apply_cache_diff(x)
|
return easycache.apply_cache_diff(x)
|
||||||
else:
|
else:
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
logger.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||||
easycache.cumulative_change_rate = 0.0
|
easycache.cumulative_change_rate = 0.0
|
||||||
output: torch.Tensor = executor(*args, **kwargs)
|
output: torch.Tensor = executor(*args, **kwargs)
|
||||||
if easycache.has_output_prev_norm():
|
if easycache.has_output_prev_norm():
|
||||||
@ -120,18 +121,18 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
|||||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||||
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
logger.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
||||||
if input_change is not None:
|
if input_change is not None:
|
||||||
easycache.relative_transformation_rate = output_change / input_change
|
easycache.relative_transformation_rate = output_change / input_change
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
|
logger.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
|
||||||
# TODO: allow cache_diff to be offloaded
|
# TODO: allow cache_diff to be offloaded
|
||||||
easycache.update_cache_diff(output, next_x_prev)
|
easycache.update_cache_diff(output, next_x_prev)
|
||||||
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
|
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
|
||||||
easycache.output_prev_subsampled = easycache.subsample(output)
|
easycache.output_prev_subsampled = easycache.subsample(output)
|
||||||
easycache.output_prev_norm = output.flatten().abs().mean()
|
easycache.output_prev_norm = output.flatten().abs().mean()
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
logger.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
|
def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
|
||||||
@ -152,22 +153,22 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
|
|||||||
# clone and prepare timesteps
|
# clone and prepare timesteps
|
||||||
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
|
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
|
||||||
easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
|
easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
|
||||||
logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
|
logger.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
|
||||||
return executor(*args, **kwargs)
|
return executor(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
easycache = guider.model_options['transformer_options']['easycache']
|
easycache = guider.model_options['transformer_options']['easycache']
|
||||||
output_change_rates = easycache.output_change_rates
|
output_change_rates = easycache.output_change_rates
|
||||||
approx_output_change_rates = easycache.approx_output_change_rates
|
approx_output_change_rates = easycache.approx_output_change_rates
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
|
logger.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
|
||||||
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
|
logger.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
|
||||||
total_steps = len(args[3])-1
|
total_steps = len(args[3])-1
|
||||||
# catch division by zero for log statement; sucks to crash after all sampling is done
|
# catch division by zero for log statement; sucks to crash after all sampling is done
|
||||||
try:
|
try:
|
||||||
speedup = total_steps/(total_steps-easycache.total_steps_skipped)
|
speedup = total_steps/(total_steps-easycache.total_steps_skipped)
|
||||||
except ZeroDivisionError:
|
except ZeroDivisionError:
|
||||||
speedup = 1.0
|
speedup = 1.0
|
||||||
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
|
logger.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
|
||||||
easycache.reset()
|
easycache.reset()
|
||||||
guider.model_options = orig_model_options
|
guider.model_options = orig_model_options
|
||||||
|
|
||||||
@ -298,7 +299,7 @@ class EasyCacheHolder:
|
|||||||
return True
|
return True
|
||||||
if metadata == self.state_metadata:
|
if metadata == self.state_metadata:
|
||||||
return True
|
return True
|
||||||
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
logger.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
||||||
self.reset()
|
self.reset()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -432,7 +433,7 @@ class LazyCacheHolder:
|
|||||||
return True
|
return True
|
||||||
if metadata == self.state_metadata:
|
if metadata == self.state_metadata:
|
||||||
return True
|
return True
|
||||||
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
logger.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
||||||
self.reset()
|
self.reset()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user