Improve tests, fix issues with alternate filenames, improve group offloading support for transformers models

doctorpangloss 2025-09-18 13:25:08 -07:00
parent 79b8723f61
commit bc201cea4d
14 changed files with 203 additions and 87 deletions

View File

@@ -242,7 +242,7 @@ ComfyUI LTS supports text and multi-modal LLM models from the `transformers` eco
In this example, LLAVA-NEXT (LLAVA 1.6) is prompted to describe an image.
You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-3-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json).
You can try the [LLAVA-NEXT](tests/inference/workflows/llava-0.json), [Phi-3](tests/inference/workflows/phi-4-0.json), and two [translation](tests/inference/workflows/translation-0.json) [workflows](tests/inference/workflows/translation-1.json).
# SVG Conversion and String Saving

View File

@@ -130,7 +130,7 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true",
help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in VRAM when it can.")
parser.add_argument("--deterministic", action="store_true",
help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

View File

@@ -12,9 +12,12 @@ from typing import Optional, Any, Callable
import torch
import transformers
from huggingface_hub.errors import EntryNotFoundError
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \
PretrainedConfig, TextStreamer, LogitsProcessor
from huggingface_hub import hf_api
from huggingface_hub.file_download import hf_hub_download
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
@@ -25,7 +28,7 @@ from .. import model_management
from ..component_model.tensor_types import RGBImageBatch
from ..model_downloader import get_or_download_huggingface_repo
from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
from ..model_management_types import ModelManageable
from ..model_management_types import ModelManageableStub
from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
logger = logging.getLogger(__name__)
@@ -37,7 +40,7 @@ _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING
_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'}
class TransformersManagedModel(ModelManageable, LanguageModel):
class TransformersManagedModel(ModelManageableStub, LanguageModel):
def __init__(
self,
repo_id: str,
@@ -69,7 +72,20 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
hub_kwargs["subfolder"] = subfolder
repo_id = ckpt_name
with comfy_tqdm():
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
ckpt_name = get_or_download_huggingface_repo(repo_id)
if config_dict is None:
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
elif isinstance(config_dict, PretrainedConfig):
config_dict: dict = config_dict.to_dict()
else:
config_dict = {}
try:
model_type = config_dict["model_type"]
except KeyError:
logger.debug(f"Configuration was missing for repo_id={repo_id}")
model_type = ""
from_pretrained_kwargs = {
"pretrained_model_name_or_path": ckpt_name,
@@ -77,19 +93,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
**hub_kwargs
}
# compute bitsandbytes configuration
try:
import bitsandbytes
except ImportError:
pass
if config_dict is None:
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
elif isinstance(config_dict, PretrainedConfig):
config_dict: dict = config_dict.to_dict()
model_type = config_dict["model_type"]
# language models prefer to use bfloat16 over float16
kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
kwargs_to_try = ({"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
"low_cpu_mem_usage": True,
"device_map": str(unet_offload_device()), }, {})

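For reference, the loading path above now reads the repository configuration up front with PretrainedConfig.get_config_dict and switches from the deprecated torch_dtype keyword to the dtype keyword accepted by recent transformers releases. A minimal sketch of the same probe-then-load-with-fallback pattern (not from this commit; the helper name is illustrative):

import torch
from transformers import AutoModelForCausalLM, PretrainedConfig

def probe_and_load(ckpt_name: str, **hub_kwargs):
    # read the config as a plain dict without instantiating the model
    config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
    model_type = config_dict.get("model_type", "")  # "" when the repo ships no usable config
    preferred = {
        "dtype": torch.bfloat16,   # language models prefer bfloat16 over float16
        "low_cpu_mem_usage": True,
        "device_map": "cpu",       # keep weights on the offload device until they are managed
    }
    for extra in (preferred, {}):  # mirrors kwargs_to_try: fall back to a bare call
        try:
            return model_type, AutoModelForCausalLM.from_pretrained(ckpt_name, **extra, **hub_kwargs)
        except Exception:          # older releases reject or mishandle some of these kwargs
            continue
    raise RuntimeError(f"could not load {ckpt_name}")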
View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union
from typing import Protocol, List, Dict, Optional, NamedTuple, Callable, Literal, Any, TypeAlias, Union, runtime_checkable
import torch
PatchOffset = tuple[int, int, int]
@@ -31,6 +31,7 @@ class PatchTuple(NamedTuple):
ModelPatchesDictValue: TypeAlias = list[Union[PatchTuple, PatchWeightTuple]]
@runtime_checkable
class PatchSupport(Protocol):
"""
Defines the interface for a model that supports LoRA patching.
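Marking the protocol @runtime_checkable is what lets the isinstance checks used elsewhere in this commit work against it; the check is purely structural. A stdlib-only sketch (the protocol name below is an illustrative stand-in, not the real PatchSupport):

from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsPatching(Protocol):      # illustrative stand-in for PatchSupport
    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): ...

class Plain:
    pass

class Patchable:
    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        return list(patches)

# isinstance() only verifies that the named members exist; signatures are not checked
assert not isinstance(Plain(), SupportsPatching)
assert isinstance(Patchable(), SupportsPatching)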

View File

@@ -725,15 +725,15 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)
def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]:
with comfy_tqdm():
return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs)
return _get_or_download_huggingface_repo(repo_id, cache_dirs, local_dirs, force=force, subset=subset)
def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None, force: bool = False, subset: bool = False) -> Optional[str]:
cache_dirs = cache_dirs or folder_paths.get_folder_paths("huggingface_cache")
local_dirs = local_dirs or folder_paths.get_folder_paths("huggingface")
cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)
cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id, subset=subset)
local_dirs_cache_hit = len(local_dirs_snapshots) > 0
cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0
@@ -742,25 +742,25 @@ def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] =
# if we're in forced local directory mode, only use the local dir snapshots, and otherwise, download
if args.force_hf_local_dir_mode:
# todo: we still have to figure out a way to download things to the right places by default
if len(local_dirs_snapshots) > 0:
if len(local_dirs_snapshots) > 0 and not force:
return local_dirs_snapshots[0]
elif not args.disable_known_models:
destination = os.path.join(local_dirs[0], repo_id)
logger.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
return snapshot_download(repo_id, local_dir=destination)
return snapshot_download(repo_id, local_dir=destination, force_download=force)
snapshots = local_dirs_snapshots + cache_dirs_snapshots
if len(snapshots) > 0:
if len(snapshots) > 0 and not force:
return snapshots[0]
elif not args.disable_known_models:
logger.debug(f"downloading repo_id={repo_id}")
return snapshot_download(repo_id)
return snapshot_download(repo_id, force_download=force)
# this repo was not found
return None
def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id):
def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id, subset=False):
local_dirs_snapshots = []
cache_dirs_snapshots = []
# find all the pre-existing downloads for this repo_id
@@ -772,13 +772,12 @@ def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_i
if len(repo_files) > 0:
for local_dir in local_dirs:
local_path = Path(local_dir) / repo_id
local_files = set(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
local_files = frozenset(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
# fix path representation
local_files = set(f.replace("\\", "/") for f in local_files)
local_files = frozenset(f.replace("\\", "/") for f in local_files)
# remove .huggingface
local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
# local_files.issubsetof(repo_files)
if len(local_files) > 0 and local_files.issubset(repo_files):
local_files = frozenset(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
if len(local_files) > 0 and ((subset and local_files.issubset(repo_files)) or (not subset and repo_files.issubset(local_files))):
local_dirs_snapshots.append(str(local_path))
else:
# an empty repository or unknown repository info, trust that if the directory exists, it matches
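The new subset flag flips the direction of the containment test: by default a local directory only counts as a cache hit when every file the Hub lists is already on disk, while subset=True also accepts a partial copy. A stdlib-only sketch of that decision (not from this commit; the file lists are made up):

def is_cache_hit(local_files, repo_files, subset=False):
    if not local_files:
        return False
    # subset=True: a partial local copy (e.g. only some shards) is acceptable;
    # default: require the full repository listing to be on disk already.
    return local_files.issubset(repo_files) if subset else repo_files.issubset(local_files)

repo_files = frozenset({"org/model/config.json", "org/model/model.safetensors"})
local_files = frozenset({"org/model/config.json", "org/model/model.safetensors", "org/model/notes.txt"})

assert is_cache_hit(local_files, repo_files)                                 # full copy plus extras: hit
assert not is_cache_hit(frozenset({"org/model/config.json"}), repo_files)    # partial copy: miss by default
assert is_cache_hit(frozenset({"org/model/config.json"}), repo_files, subset=True)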

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import copy
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple, override, TYPE_CHECKING
import torch
import torch.nn
@@ -11,22 +12,48 @@ from typing_extensions import TypedDict, NotRequired
from .comfy_types import UnetWrapperFunction
from .latent_formats import LatentFormat
if TYPE_CHECKING:
from .hooks import EnumHookMode
ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
LatentFormatT = TypeVar('LatentFormatT', bound=LatentFormat)
@runtime_checkable
class DeviceSettable(Protocol):
@property
def device(self) -> torch.device:
...
@device.setter
def device(self, value: torch.device):
...
device: torch.device
class HooksSupport(Protocol, metaclass=ABCMeta):
@runtime_checkable
class HooksSupport(Protocol):
wrappers: dict[str, dict[str, list[Callable]]]
callbacks: dict[str, dict[str, list[Callable]]]
hook_mode: "EnumHookMode"
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): ...
def model_patches_models(self) -> list[ModelManageableT]: ...
def restore_hook_patches(self): ...
def cleanup(self): ...
def pre_run(self): ...
def prepare_state(self, *args, **kwargs): ...
def register_all_hook_patches(self, a, b, c, d): ...
def get_nested_additional_models(self): ...
def apply_hooks(self, *args, **kwargs): ...
def add_wrapper(self, wrapper_type: str, wrapper: Callable): ...
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): ...
class HooksSupportStub(HooksSupport, metaclass=ABCMeta):
def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options):
return
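This module now follows a protocol-plus-stub convention: the @runtime_checkable Protocol declares the surface, the *Stub class supplies harmless defaults, and concrete wrappers override only what they need. A small sketch of the pattern (class names are stand-ins, not the real HooksSupport classes):

from typing import Protocol, runtime_checkable

@runtime_checkable
class Hookable(Protocol):              # stand-in for HooksSupport
    def restore_hook_patches(self): ...
    def cleanup(self): ...

class HookableStub(Hookable):          # stand-in for HooksSupportStub: no-op defaults
    def restore_hook_patches(self):
        return
    def cleanup(self):
        return

class ManagedThing(HookableStub):      # a concrete wrapper overrides only what matters
    def cleanup(self):
        self.buffers = []

obj = ManagedThing()
obj.cleanup()
assert isinstance(obj, Hookable)       # structural check enabled by @runtime_checkable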
@@ -82,6 +109,8 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
if isinstance(model, BaseModel) or hasattr(model, "current_patcher") and isinstance(self, ModelManageable):
model.current_patcher = self
def prepare_state(self, *args, **kwargs):
pass
@@ -94,8 +123,22 @@ class HooksSupport(Protocol, metaclass=ABCMeta):
def apply_hooks(self, *args, **kwargs):
return {}
def add_wrapper(self, wrapper_type: str, wrapper: Callable):
self.add_wrapper_with_key(wrapper_type, None, wrapper)
class TrainingSupport(Protocol, metaclass=ABCMeta):
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
@runtime_checkable
class TrainingSupport(Protocol):
def set_model_compute_dtype(self, dtype: torch.dtype): ...
def add_weight_wrapper(self, name, function): ...
class TrainingSupportStub(TrainingSupport, metaclass=ABCMeta):
def set_model_compute_dtype(self, dtype: torch.dtype):
return
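add_wrapper_with_key above builds a two-level registry: wrapper type, then an optional key (None for plain add_wrapper), then a list of callables. A compact sketch of that bookkeeping (not from this commit; the wrapper-type string is made up):

wrappers = {}   # {wrapper_type: {key_or_None: [callables]}}

def add_wrapper_with_key(wrapper_type, key, wrapper):
    wrappers.setdefault(wrapper_type, {}).setdefault(key, []).append(wrapper)

def add_wrapper(wrapper_type, wrapper):
    add_wrapper_with_key(wrapper_type, None, wrapper)       # unkeyed wrappers live under None

def log_calls(executor, *args, **kwargs):                   # illustrative wrapper
    return executor(*args, **kwargs)

add_wrapper("prepare_sampling", log_calls)                  # "prepare_sampling" is a made-up type name
add_wrapper_with_key("prepare_sampling", "offload", log_calls)
assert list(wrappers["prepare_sampling"]) == [None, "offload"]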
@@ -103,13 +146,68 @@ class TrainingSupport(Protocol, metaclass=ABCMeta):
return
class ModelManageableExtras(Protocol, metaclass=ABCMeta):
@runtime_checkable
class ModelManageable(HooksSupport, TrainingSupport, Protocol):
"""
Objects which implement this protocol can be managed by
>>> from comfy.model_management import load_models_gpu
>>> class ModelWrapper(ModelManageable):
>>> ...
>>>
>>> some_model = ModelWrapper()
>>> load_models_gpu([some_model])
The minimum required
"""
load_device: torch.device
offload_device: torch.device
model: torch.nn.Module
@property
def current_device(self) -> torch.device:
return torch.device("cpu")
def current_device(self) -> torch.device: ...
def is_clone(self, other: ModelManageableT) -> bool: ...
def clone_has_same_weights(self, clone: ModelManageableT) -> bool: ...
def model_size(self) -> int: ...
def model_patches_to(self, arg: torch.device | torch.dtype): ...
def model_dtype(self) -> torch.dtype: ...
def lowvram_patch_counter(self) -> int: ...
def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False) -> int: ...
def partially_unload(self, device_to: torch.device, memory_to_free: int = 0) -> int: ...
def memory_required(self, input_shape: torch.Size) -> int: ...
def loaded_size(self) -> int: ...
def current_loaded_device(self) -> torch.device: ...
def get_model_object(self, name: str) -> torch.nn.Module: ...
@property
def model_options(self) -> ModelOptions: ...
@model_options.setter
def model_options(self, value): ...
def __del__(self): ...
@property
def parent(self) -> ModelManageableT | None: ...
def detach(self, unpatch_all: bool = True): ...
def clone(self) -> ModelManageableT: ...
class ModelManageableRequired(Protocol, metaclass=ABCMeta):
class ModelManageableStub(HooksSupportStub, TrainingSupportStub, ModelManageable, metaclass=ABCMeta):
"""
The bare minimum that must be implemented to support model management when inheriting from ModelManageable
@@ -120,12 +218,11 @@ class ModelManageableRequired(Protocol, metaclass=ABCMeta):
:see: ModelManageable
:see: PatchSupport
"""
load_device: torch.device
offload_device: torch.device
model: torch.nn.Module
@abstractmethod
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True,
force_patch_weights: bool = False) -> torch.nn.Module:
"""
Called by ModelManageable
@@ -155,25 +252,6 @@
"""
...
@runtime_checkable
class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol, metaclass=ABCMeta):
"""
Objects which implement this protocol can be managed by
>>> from comfy.model_management import load_models_gpu
>>> class ModelWrapper(ModelManageable):
>>> ...
>>>
>>> some_model = ModelWrapper()
>>> load_models_gpu([some_model])
The minimum required
"""
load_device: torch.device
offload_device: torch.device
model: torch.nn.Module
@property
@override
def current_device(self) -> torch.device:
@@ -265,6 +343,9 @@ class ModelManageable(ModelManageableRequired, ModelManageableExtras, Protocol,
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
return self.model
def clone(self) -> ModelManageableT:
return copy.copy(self)
@dataclasses.dataclass
class MemoryMeasurements:
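For orientation, a minimal subclass of the new ModelManageableStub might look roughly like the following sketch; it assumes, based on the abstract surface shown above, that patch_model and unpatch_model plus the three device/model attributes are all a concrete wrapper still has to provide:

import torch
from comfy.model_management_types import ModelManageableStub

class TinyManagedModel(ModelManageableStub):         # hypothetical example wrapper
    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.load_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.offload_device = torch.device("cpu")

    def patch_model(self, device_to=None, lowvram_model_memory=0,
                    load_weights=True, force_patch_weights=False) -> torch.nn.Module:
        if device_to is not None:
            self.model.to(device_to)                 # move weights onto the load device
        return self.model

    def unpatch_model(self, device_to=None, unpatch_weights=True) -> torch.nn.Module:
        if device_to is not None:
            self.model.to(device_to)                 # park weights on the offload device
        return self.model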

View File

@@ -230,7 +230,7 @@ class GGUFQuantization:
patch_on_device: bool = False
class ModelPatcher(ModelManageable, TrainingSupport, HooksSupport, PatchSupport):
class ModelPatcher(ModelManageable, PatchSupport):
def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
self.size = size
self.model: BaseModel | torch.nn.Module = model

View File

@@ -27,6 +27,7 @@ import os
import random
import struct
import sys
import threading
import warnings
from contextlib import contextmanager
from pathlib import Path
@@ -57,6 +58,9 @@ DISABLE_MMAP = args.disable_mmap
logger = logging.getLogger(__name__)
ALWAYS_SAFE_LOAD = False
_lock = threading.RLock()
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
class ModelCheckpoint:
pass
@@ -1178,7 +1182,9 @@ class ProgressBar:
def update_absolute(self, value, total=None, preview_image_or_output=None):
if total is not None:
self.total = total
if value > self.total:
if value is None:
return
if self.total is not None and value > self.total:
value = self.total
self.current = value
_progress_bar_update(self.current, self.total, preview_image_or_output, server=self.server, node_id=self.node_id)
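The extra guards above make update_absolute tolerant of the partial information tqdm callbacks can pass through. A short usage sketch (not from this commit; it assumes ProgressBar can be constructed outside a running server):

from comfy.utils import ProgressBar

bar = ProgressBar(100)
bar.update_absolute(42)      # normal progress report
bar.update_absolute(None)    # now ignored instead of failing the comparison against total
bar.update_absolute(250)     # clamped to the declared total of 100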
@@ -1198,31 +1204,39 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
Monkey patches child calls to tqdm, sends progress to the UI,
and yields a watcher object for stall detection.
"""
with _lock:
if hasattr(tqdm, "__patched_by_comfyui__"):
yield getattr(tqdm, "__patched_by_comfyui__")
return
watcher = TqdmWatcher()
setattr(tqdm, "__patched_by_comfyui__", watcher)
_original_init = tqdm.__init__
_original_call = tqdm.__call__
_original_update = tqdm.update
# Create the watcher instance that the patched methods will update
# and that will be yielded to the caller.
watcher = TqdmWatcher()
context = contextvars.copy_context()
try:
# These inner functions are closures; they capture the `watcher` variable
# from the enclosing scope.
def __init(self, *args, **kwargs):
context.run(lambda: _original_init(self, *args, **kwargs))
_original_init(self, *args, **kwargs)
self._progress_bar = context.run(lambda: ProgressBar(self.total))
watcher.tick() # Signal progress on initialization
def __update(self, n=1):
assert self._progress_bar is not None
context.run(lambda: _original_update(self, n))
_original_update(self, n)
context.run(lambda: self._progress_bar.update(n))
watcher.tick() # Signal progress on update
def __call(self, *args, **kwargs):
instance = context.run(lambda: _original_call(self, *args, **kwargs))
instance = _original_call(self, *args, **kwargs)
return instance
tqdm.__init__ = __init
@@ -1236,10 +1250,11 @@ def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
tqdm.__init__ = _original_init
tqdm.__call__ = _original_call
tqdm.update = _original_update
delattr(tqdm, "__patched_by_comfyui__")
@contextmanager
def comfy_progress(total: float) -> ProgressBar:
def comfy_progress(total: float) -> Generator[ProgressBar, Any, None]:
ctx = current_execution_context()
if ctx.server.receive_all_progress_notifications:
yield ProgressBar(total)
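The comfy_tqdm changes above add a re-entrancy guard: the first caller patches tqdm and stashes the watcher on the class, nested callers reuse it, and a process-wide lock keeps concurrent threads from racing the patch and unpatch. A standalone sketch of that guard pattern (not the actual ComfyUI code; it assumes only the tqdm package):

import threading
from contextlib import contextmanager
from tqdm import tqdm

_lock = threading.RLock()

@contextmanager
def patched_tqdm():
    with _lock:
        watcher = getattr(tqdm, "__patched_by_example__", None)
        nested = watcher is not None
        if not nested:
            watcher = {"ticks": 0}
            original_update = tqdm.update
            def update(self, n=1):
                watcher["ticks"] += 1                    # progress signal for stall detection
                return original_update(self, n)
            tqdm.update = update
            tqdm.__patched_by_example__ = watcher
    try:
        yield watcher                                    # nested callers share the same watcher
    finally:
        if not nested:
            with _lock:
                tqdm.update = original_update
                delattr(tqdm, "__patched_by_example__")

with patched_tqdm() as outer:
    with patched_tqdm() as inner:                        # re-entrant: no double patching
        assert inner is outer
    bar = tqdm(total=3)
    bar.update(3)                                        # explicit update goes through the patch
    bar.close()
assert outer["ticks"] >= 1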

View File

@@ -765,7 +765,7 @@ class DualCFGGuider:
FUNCTION = "get_guider"
CATEGORY = "sampling/custom_sampling/guiders"
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style="regular"):
guider = Guider_DualCFG(model)
guider.set_conds(cond1, cond2, negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))

View File

@@ -2,7 +2,9 @@ import torch
from diffusers import HookRegistry
from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_management import vram_state, VRAMState
from comfy.model_management_types import HooksSupport, ModelManageable
from comfy.model_patcher import ModelPatcher
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode
@@ -117,8 +119,20 @@ class GroupOffload(CustomNode):
RETURN_TYPES = ("MODEL",)
FUNCTION = "execute"
def execute(self, model: ModelPatcher) -> tuple[ModelPatcher,]:
def execute(self, model: ModelManageable | HooksSupport | TransformersManagedModel) -> tuple[ModelPatcher,]:
if isinstance(model, ModelManageable):
model = model.clone()
if isinstance(model, TransformersManagedModel):
apply_group_offloading(
model.model,
model.load_device,
model.offload_device,
use_stream=True,
record_stream=True,
low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
num_blocks_per_group=1
)
elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable):
model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
return model,
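Outside the node graph, the TransformersManagedModel branch above amounts to handing the wrapped PreTrainedModel to diffusers' group offloading. A hedged sketch with a plain transformers model (the repo id mirrors the test workflow; a CUDA device is assumed because use_stream relies on it):

import torch
from diffusers.hooks import apply_group_offloading
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "microsoft/Phi-4-mini-instruct"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

apply_group_offloading(
    model,
    torch.device("cuda"),        # onload device: blocks are moved here just-in-time
    torch.device("cpu"),         # offload device: blocks are parked here between uses
    num_blocks_per_group=1,      # one group of blocks resident at a time
    use_stream=True,             # overlap copies with compute on a side stream
    record_stream=True,
)

inputs = tokenizer("What comes after apple?", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))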

View File

@@ -42,6 +42,7 @@ class LTXVImgToVideo:
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}, "optional": {
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
}}
@@ -51,7 +52,7 @@ class LTXVImgToVideo:
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength):
def generate(self, positive, negative, image, vae, width, height, length, batch_size, strength=1.0):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)

View File

@@ -9,7 +9,7 @@ from comfy import utils
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
from comfy.model_management import load_models_gpu
from comfy.model_management_types import ModelManageable
from comfy.model_management_types import ModelManageableStub
logger = logging.getLogger(__name__)
try:
@@ -22,7 +22,7 @@ except:
pass
class UpscaleModelManageable(ModelManageable):
class UpscaleModelManageable(ModelManageableStub):
def __init__(self, model_descriptor: ImageModelDescriptor, ckpt_name: str):
self.ckpt_name = ckpt_name
self.model_descriptor = model_descriptor

View File

@@ -182,7 +182,7 @@ async def test_huggingface_alternate_filenames_in_combo():
)
# 3. Get the list of files as the UI would
filename_list = get_filename_list_with_downloadable("checkpoints", known_files=[known_file])
filename_list = get_filename_list_with_downloadable("__xxx___", known_files=[known_file])
# 4. Assert that both the main and alternate filenames are present
assert main_filename in filename_list

View File

@@ -1,7 +1,7 @@
{
"1": {
"inputs": {
"ckpt_name": "microsoft/Phi-3-mini-4k-instruct",
"ckpt_name": "microsoft/Phi-4-mini-instruct",
"subfolder": ""
},
"class_type": "TransformersLoader",
@@ -33,7 +33,7 @@
"4": {
"inputs": {
"prompt": "What comes after apple?",
"chat_template": "phi-3",
"chat_template": "default",
"model": [
"1",
0