mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-21 20:00:17 +08:00
Better support for transformers t5
This commit is contained in:
parent
8a1557f750
commit
e7682ced56
@ -159,12 +159,13 @@ class TransformersManagedModel(ModelManageable):
|
|||||||
filename, chat_template = candidate_chat_templates[0]
|
filename, chat_template = candidate_chat_templates[0]
|
||||||
logging.debug(f"Selected chat template filename={filename} for {self.model.name_or_path}")
|
logging.debug(f"Selected chat template filename={filename} for {self.model.name_or_path}")
|
||||||
try:
|
try:
|
||||||
# todo: this should come from node inputs
|
if hasattr(tokenizer, "apply_chat_template"):
|
||||||
prompt = tokenizer.apply_chat_template([
|
# todo: this should come from node inputs
|
||||||
{"role": "user", "content": prompt},
|
prompt = tokenizer.apply_chat_template([
|
||||||
], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
|
{"role": "user", "content": prompt},
|
||||||
|
], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logging.error("Could not apply chat template", exc_info=exc)
|
logging.debug("Could not apply chat template", exc_info=exc)
|
||||||
|
|
||||||
if self.processor is None:
|
if self.processor is None:
|
||||||
batch_encoding = tokenizer(prompt, return_tensors="pt").to(device=self.load_device)
|
batch_encoding = tokenizer(prompt, return_tensors="pt").to(device=self.load_device)
|
||||||
|
|||||||
@ -167,6 +167,13 @@ KNOWN_CHECKPOINTS = [
|
|||||||
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
|
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
|
||||||
CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
|
CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
|
||||||
CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
|
CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
|
||||||
|
HuggingFile("SG161222/Realistic_Vision_V6.0_B1_noVAE","Realistic_Vision_V6.0_NV_B1_fp16.safetensors"),
|
||||||
|
HuggingFile("SG161222/Realistic_Vision_V5.1_noVAE","Realistic_Vision_V5.1_fp16-no-ema.safetensors"),
|
||||||
|
CivitFile(4384, 128713, filename="dreamshaper_8.safetensors"),
|
||||||
|
CivitFile(7371, 425083, filename="revAnimated_v2Rebirth.safetensors"),
|
||||||
|
CivitFile(4468, 57618, filename="counterfeitV30_v30.safetensors"),
|
||||||
|
CivitFile(241415, 272376, filename="picxReal_10.safetensors"),
|
||||||
|
CivitFile(23900, 95489, filename="anyloraCheckpoint_bakedvaeBlessedFp16.safetensors"),
|
||||||
]
|
]
|
||||||
|
|
||||||
KNOWN_UNCLIP_CHECKPOINTS = [
|
KNOWN_UNCLIP_CHECKPOINTS = [
|
||||||
@ -195,8 +202,9 @@ KNOWN_CLIP_VISION_MODELS = [
|
|||||||
KNOWN_LORAS = [
|
KNOWN_LORAS = [
|
||||||
CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"),
|
CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"),
|
||||||
CivitFile(model_id=324815, model_version_id=364137, filename="blur_control_xl_v1.safetensors"),
|
CivitFile(model_id=324815, model_version_id=364137, filename="blur_control_xl_v1.safetensors"),
|
||||||
|
CivitFile(model_id=47085, model_version_id=55199, filename="GoodHands-beta2.safetensors"),
|
||||||
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-12steps-CFG-lora.safetensors"),
|
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-12steps-CFG-lora.safetensors"),
|
||||||
HuggingFile("ByteDance/Hyper-SD", "Hyper-SD15-12steps-CFG-lora.safetensors")
|
HuggingFile("ByteDance/Hyper-SD", "Hyper-SD15-12steps-CFG-lora.safetensors"),
|
||||||
]
|
]
|
||||||
|
|
||||||
KNOWN_CONTROLNETS = [
|
KNOWN_CONTROLNETS = [
|
||||||
|
|||||||
@ -44,9 +44,19 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
|||||||
exported_nodes = ExportedNodes()
|
exported_nodes = ExportedNodes()
|
||||||
timings = []
|
timings = []
|
||||||
exceptions = []
|
exceptions = []
|
||||||
if _import_nodes_in_module(exported_nodes, module):
|
with tracer.start_as_current_span("Load Node") as span:
|
||||||
pass
|
time_before = time.perf_counter()
|
||||||
else:
|
try:
|
||||||
|
module_decl = _import_nodes_in_module(exported_nodes, module)
|
||||||
|
full_name = module.__name__
|
||||||
|
span.set_attribute("full_name", full_name)
|
||||||
|
timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
|
||||||
|
except Exception as exc:
|
||||||
|
logging.error(f"{full_name} import failed", exc_info=exc)
|
||||||
|
span.set_status(Status(StatusCode.ERROR))
|
||||||
|
span.record_exception(exc)
|
||||||
|
exceptions.append(exc)
|
||||||
|
if module_decl is None or not module_decl:
|
||||||
# Iterate through all the submodules
|
# Iterate through all the submodules
|
||||||
for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
|
for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
|
||||||
span: Span
|
span: Span
|
||||||
@ -55,6 +65,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
|||||||
time_before = time.perf_counter()
|
time_before = time.perf_counter()
|
||||||
success = True
|
success = True
|
||||||
span.set_attribute("full_name", full_name)
|
span.set_attribute("full_name", full_name)
|
||||||
|
new_nodes = ExportedNodes()
|
||||||
if full_name.endswith(".disabled"):
|
if full_name.endswith(".disabled"):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
@ -75,11 +86,11 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
|||||||
exceptions.append(x)
|
exceptions.append(x)
|
||||||
span.set_status(Status(StatusCode.ERROR))
|
span.set_status(Status(StatusCode.ERROR))
|
||||||
span.record_exception(x)
|
span.record_exception(x)
|
||||||
timings.append((time.perf_counter() - time_before, full_name, success))
|
timings.append((time.perf_counter() - time_before, full_name, success, new_nodes))
|
||||||
|
|
||||||
if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings):
|
if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings):
|
||||||
for (duration, module_name, success) in sorted(timings):
|
for (duration, module_name, success, new_nodes) in sorted(timings):
|
||||||
print(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name}")
|
logging.info(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
|
||||||
if raise_on_failure and len(exceptions) > 0:
|
if raise_on_failure and len(exceptions) > 0:
|
||||||
try:
|
try:
|
||||||
raise ExceptionGroup("Node import failed", exceptions)
|
raise ExceptionGroup("Node import failed", exceptions)
|
||||||
|
|||||||
0
comfy/t5_tokenizer/__init__.py
Normal file
0
comfy/t5_tokenizer/__init__.py
Normal file
@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, List, Callable, Union
|
|||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
|
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
|
||||||
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
|
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
|
||||||
LlavaNextForConditionalGeneration, LlavaNextProcessor
|
LlavaNextForConditionalGeneration, LlavaNextProcessor, T5EncoderModel, AutoModel
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
||||||
@ -245,8 +245,9 @@ class TransformersLoader(CustomNode):
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs)
|
model = AutoModel.from_pretrained(**from_pretrained_kwargs)
|
||||||
except:
|
except:
|
||||||
|
# not yet supported by automodel
|
||||||
model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
|
model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
|
||||||
|
|
||||||
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
|
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
|
||||||
|
|||||||
7
setup.py
7
setup.py
@ -185,7 +185,12 @@ def dependencies(force_nightly: bool = False) -> List[str]:
|
|||||||
return _dependencies
|
return _dependencies
|
||||||
|
|
||||||
|
|
||||||
package_data = ['sd1_tokenizer/*', '**/*.json', '**/*.yaml']
|
package_data = [
|
||||||
|
'sd1_tokenizer/*',
|
||||||
|
't5_tokenizer/*',
|
||||||
|
'**/*.json',
|
||||||
|
'**/*.yaml',
|
||||||
|
]
|
||||||
if not is_editable:
|
if not is_editable:
|
||||||
package_data.append('comfy/web/**/*')
|
package_data.append('comfy/web/**/*')
|
||||||
dev_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-dev.txt")).readlines()
|
dev_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-dev.txt")).readlines()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user