mirror of https://github.com/comfyanonymous/ComfyUI.git

add video support for language models, improve torch compile for transformers models, make it easier to load videos, fix progress messages from language nodes

parent 81ea9726b6
commit b149031748
@@ -1161,8 +1161,12 @@ class PromptServer(ExecutorToClientProgress):
             await self.send_image(data, sid=sid)
         elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
             # data is (preview_image, metadata)
+
             data: PreviewImageWithMetadataMessage
             preview_image, metadata = data
+            if isinstance(preview_image, dict):
+                # todo: this has to be fixed from transformers loader for previewing tokens in real time
+                return
             await self.send_image_with_metadata(preview_image, metadata, sid=sid)
         elif isinstance(data, (bytes, bytearray)):
             await self.send_bytes(event, data, sid)
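Note on the guard added above: image previews flow through as tuple payloads that `send_image_with_metadata` can serialize, but the transformers loader streams token-by-token progress as a plain dict, so the handler bails out early. A minimal sketch of the dispatch, with the dict payload shape assumed for illustration since the commit only marks it as a todo:

```python
from PIL import Image

previews = [
    (Image.new("RGB", (64, 64)), {"node_id": "12"}),  # normal image preview tuple
    {"text": "partial tokens..."},                    # assumed shape of a streamed-text preview
]

for preview_image in previews:
    if isinstance(preview_image, dict):
        continue  # mirrors the early `return`: dict previews are skipped for now
    # ... otherwise it would be sent via send_image_with_metadata(...)
```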
@@ -92,14 +92,13 @@ class LanguageModel(Protocol):
 
     def generate(self, tokens: TOKENS_TYPE = None,
                  max_new_tokens: int = 512,
-                 repetition_penalty: float = 0.0,
                  seed: int = 0,
                  sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                  *args,
                  **kwargs) -> str:
         ...
 
-    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
+    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, videos: list[torch.Tensor] | None, chat_template: str | None = None) -> ProcessorResult:
         ...
 
     @property
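Downstream of this protocol change, implementations and callers accept a `videos` argument between `images` and `chat_template`. A hedged usage sketch, assuming `model` is any loaded `LanguageModel` implementation:

```python
import torch

clip = torch.rand(16, 224, 224, 3)  # one video as a (frames, H, W, C) float tensor in [0, 1]

tokens = model.tokenize(
    "Describe what happens in this clip.",
    images=None,
    videos=[clip],
    chat_template=None,
)
text = model.generate(tokens, max_new_tokens=256, seed=42)
```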
@@ -19,7 +19,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin
 from huggingface_hub import hf_api
 from huggingface_hub.file_download import hf_hub_download
 from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
-    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, AutoModelForImageTextToText
 
 from .chat_templates import KNOWN_CHAT_TEMPLATES
 from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \
@@ -122,7 +122,7 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         for i, kwargs_to_try in enumerate(kwargses_to_try):
             try:
                 if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
-                    model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
+                    model = AutoModelForImageTextToText.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                     model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
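`AutoModelForVision2Seq` is deprecated in current transformers releases; `AutoModelForImageTextToText` is its replacement and resolves the same vision-to-sequence checkpoints plus the newer image-text-to-text architectures. Outside this loader the equivalent call looks like the following (checkpoint name illustrative):

```python
from transformers import AutoModelForImageTextToText, AutoProcessor

checkpoint = "llava-hf/llava-1.5-7b-hf"  # illustrative; any image-text-to-text repo works
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForImageTextToText.from_pretrained(checkpoint)
```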
@@ -176,7 +176,6 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
 
     def generate(self, tokens: TOKENS_TYPE = None,
                  max_new_tokens: int = 512,
-                 repetition_penalty: float = 0.0,
                  seed: int = 0,
                  sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                  *args,
@@ -257,7 +256,6 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
                 **inputs,
                 streamer=text_streamer if num_beams <= 1 else None,
                 max_new_tokens=max_new_tokens,
-                repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
                 **generate_kwargs
             )
 
@@ -364,7 +362,7 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"):
             processor.image_processor.do_rescale = False
 
-    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
+    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, videos: list[torch.Tensor] | None = None, chat_template: str | None = None) -> ProcessorResult:
         tokenizer = self.processor if self.processor is not None else self.tokenizer
         assert tokenizer is not None
         assert hasattr(tokenizer, "decode")
@@ -391,16 +389,18 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
             messages: LanguagePrompt
             if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
                 messages = prompt
-            elif "content[" in chat_template:
+            elif images is not None and len(images) > 0 or videos is not None and len(videos) > 0:
                 messages = [
                     {"role": "user",
                      "content": [
                          {
                              "type": "text",
-                             "text": prompt
+                             "text": prompt if isinstance(prompt, str) else ""
                          }
                      ] + [
                          {"type": "image"} for _ in range(len(images))
+                     ] + [
+                         {"type": "video"} for _ in range(len(videos))
                      ]
 
                      }
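For a prompt with one image and one video attached, the branch above builds a content-array message; `apply_chat_template` later expands the typed placeholders into the model's image and video special tokens. The resulting structure:

```python
messages = [
    {"role": "user",
     "content": [
         {"type": "text", "text": "Compare the image with the clip."},
         {"type": "image"},
         {"type": "video"},
     ]}
]
```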
@@ -409,6 +409,7 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
                 messages = [
                     {"role": "user", "content": prompt},
                 ]
+
             prompt = tokenizer.apply_chat_template(messages, chat_template=chat_template, add_generation_prompt=True, tokenize=False)
         except Exception as exc:
             logger.debug("Could not apply chat template", exc_info=exc)
@@ -422,7 +423,14 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
             # convert tuple to list from images.unbind() for paligemma workaround
             image_tensor_list = list(images.unbind()) if images is not None and len(images) > 0 else None
             try:
-                batch_feature: BatchFeature = self.processor(text=[prompt], images=image_tensor_list, return_tensors="pt", padding=True)
+                batch_feature: BatchFeature = self.processor(
+                    text=[prompt],
+                    images=image_tensor_list,
+                    videos=None if (videos is not None and len(videos) == 0) or (hasattr(videos, "shape") and videos.shape[0] == 0) else videos,
+                    return_tensors="pt",
+                    padding=True,
+                    input_data_format="channels_last"  # Ensure this is set for Qwen
+                )
             except TypeError as exc_info:
                 logger.warning(f"Exception while trying to run processor. Your transformers package is version {transformers.__version__} and may need to be updated")
                 raise exc_info
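ComfyUI hands tensors around as (batch, height, width, channel) floats in [0, 1], which is why `do_rescale` is disabled earlier and `input_data_format` is pinned to `channels_last` here: otherwise a processor may guess the layout from tensor shape and misread the channel axis. A standalone sketch of the same call, with the checkpoint name as an assumption:

```python
import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")  # illustrative

image = torch.rand(224, 224, 3)     # HWC, already scaled to [0, 1]
video = torch.rand(8, 224, 224, 3)  # frames stacked along dim 0

batch = processor(
    text=["<prompt rendered by apply_chat_template>"],
    images=[image],
    videos=[video],
    return_tensors="pt",
    padding=True,
    input_data_format="channels_last",
)
```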
@@ -65,7 +65,7 @@ class ProgressHandler(ABC):
             max_value: float,
             state: NodeProgressState,
             prompt_id: str,
-            image: PreviewImageTuple | None = None,
+            image: PreviewImageTuple | dict | None = None,
     ):
         """Called when a node's progress is updated"""
         pass
@@ -210,7 +210,7 @@ class WebUIProgressHandler(ProgressHandler):
             max_value: float,
             state: NodeProgressState,
             prompt_id: str,
-            image: PreviewImageTuple | None = None,
+            image: PreviewImageTuple | dict | None = None,
     ):
         # Send progress state of all nodes
         if self.registry:
@@ -233,6 +233,7 @@ class WebUIProgressHandler(ProgressHandler):
                 ),
                 "real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
             }
+            # todo: image can be a dict because of transformers loader, do we have to deal with this specially? things just work
             message: PreviewImageWithMetadataMessage = (image, metadata)
             self.server_instance.send_sync(
                 BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
@@ -300,7 +301,7 @@ class ProgressRegistry(AbstractProgressRegistry):
             handler.start_handler(node_id, entry, self.prompt_id)
 
     def update_progress(
-        self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
+        self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | dict | None = None
     ) -> None:
         """Update progress for a node"""
         entry = self.ensure_entry(node_id)
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import operator
 import os.path
+import re
 from abc import ABC, abstractmethod
 from functools import reduce
 from typing import Optional, List
@@ -332,6 +333,7 @@ class OneShotInstructTokenize(CustomNode):
             },
             "optional": {
                 "images": ("IMAGE", {}),
+                "videos": ("VIDEO", {}),
                 "system_prompt": ("STRING", {"multiline": True, "default": ""})
             }
         }
@@ -340,9 +342,8 @@ class OneShotInstructTokenize(CustomNode):
     RETURN_TYPES = (TOKENS_TYPE_NAME,)
     FUNCTION = "execute"
 
-    def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: Optional[str] = _AUTO_CHAT_TEMPLATE, system_prompt: str = "") -> ValidatedNodeResult:
+    def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, videos: list | object = None, chat_template: Optional[str] = _AUTO_CHAT_TEMPLATE, system_prompt: str = "") -> ValidatedNodeResult:
         if chat_template == _AUTO_CHAT_TEMPLATE:
-            # use an exact match
             model_name = os.path.basename(model.repo_id)
             if model_name in KNOWN_CHAT_TEMPLATES:
                 chat_template = KNOWN_CHAT_TEMPLATES[model_name]
@@ -351,22 +352,43 @@ class OneShotInstructTokenize(CustomNode):
         elif chat_template is not None:
             chat_template = KNOWN_CHAT_TEMPLATES[chat_template]
 
+        video_tensors = []
+        if videos is not None:
+            if not isinstance(videos, list):
+                videos_list = [videos]
+            else:
+                videos_list = videos
+
+            for vid in videos_list:
+                if hasattr(vid, "get_components"):
+                    components = vid.get_components()
+                    video_tensors.append(components.images)
+                elif isinstance(vid, torch.Tensor):
+                    video_tensors.append(vid)
+
         messages: LanguagePrompt | str
-        if system_prompt != "":
-            messages: LanguagePrompt = [
-                {"role": "system",
-                 "content": system_prompt},
-                {"role": "user",
-                 "content": [
-                     {"type": "text",
-                      "text": prompt}
-                 ] + [
-                     {"type": "image"} for _ in range(len(images) if images is not None else 0)
-                 ], }
+        has_images = images is not None and len(images) > 0
+        has_videos = len(video_tensors) > 0
+
+        if system_prompt != "" or has_images or has_videos:
+            user_content = [{"type": "text", "text": prompt}]
+            if has_images:
+                user_content += [{"type": "image"} for _ in range(len(images))]
+
+            if has_videos:
+                user_content += [{"type": "video"} for _ in range(len(video_tensors))]
+
+            messages = [
+                {"role": "user", "content": user_content}
             ]
+
+            if system_prompt.strip() != "":
+                messages.insert(0, {"role": "system", "content": system_prompt})
         else:
-            messages: str = prompt
-        return model.tokenize(messages, images, chat_template),
+            messages = prompt
+
+        return model.tokenize(messages, images, video_tensors, chat_template),
 
 
 class TransformersGenerate(CustomNode):
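The `hasattr` check above duck-types ComfyUI VIDEO inputs: anything exposing `get_components()` whose result carries the frame stack in an `images` field is accepted, alongside raw frame tensors. A sketch of an object satisfying that contract (container names assumed for illustration):

```python
from dataclasses import dataclass

import torch

@dataclass
class VideoComponents:      # stand-in for the components container the node reads
    images: torch.Tensor    # (frames, H, W, C)

class InMemoryVideo:
    """Anything with get_components() returning .images passes the duck-type check."""

    def __init__(self, frames: torch.Tensor):
        self._frames = frames

    def get_components(self) -> VideoComponents:
        return VideoComponents(images=self._frames)

# Both forms normalize to the same frame tensors inside the node:
# videos=InMemoryVideo(torch.rand(8, 64, 64, 3)) or videos=[torch.rand(8, 64, 64, 3)]
```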
@@ -376,8 +398,7 @@ class TransformersGenerate(CustomNode):
             "required": {
                 "model": ("MODEL",),
                 "tokens": (TOKENS_TYPE_NAME, {}),
-                "max_new_tokens": ("INT", {"default": 512, "min": 1}),
-                "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
+                "max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 0xffffffff}),
                 "seed": Seed,
             },
             "optional": {
@@ -393,11 +414,10 @@ class TransformersGenerate(CustomNode):
                 model: Optional[LanguageModel] = None,
                 tokens: TOKENS_TYPE = None,
                 max_new_tokens: int = 512,
-                repetition_penalty: float = 0.0,
                 seed: int = 0,
                 sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                 ):
-        return model.generate(tokens, max_new_tokens, repetition_penalty, seed, sampler),
+        return model.generate(tokens, max_new_tokens, seed, sampler),
 
 
 class PreviewString(CustomNode):
@@ -427,7 +447,7 @@ class SaveString(CustomNode):
                 "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
             },
             "optional": {
-                "extension": ("STRING", {"default": ".json"})
+                "extension": ("STRING", {"default": ".txt"})
             }
         }
 
@@ -439,17 +459,53 @@ class SaveString(CustomNode):
     def get_save_path(self, filename_prefix) -> SaveImagePathTuple:
         return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)
 
-    def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):
+    def execute(self, value: str | list[str] = "", filename_prefix: str = "ComfyUI", extension: str = ".txt"):
         full_output_folder, filename, counter, subfolder, filename_prefix = self.get_save_path(filename_prefix)
         if isinstance(value, str):
             value = [value]
 
         for i, value_i in enumerate(value):
             # roughly matches the behavior of save image, but does not support batch numbers
-            with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}_{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}_{extension}"), "wt+") as f:
+            with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}{extension}"), "wt+") as f:
                 f.write(value_i)
         return {"ui": {"string": value}}
 
 
+class OmitThink(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"forceInput": True}),
+            },
+        }
+
+    CATEGORY = "strings"
+    FUNCTION = "execute"
+    OUTPUT_NODE = True
+    RETURN_TYPES = ("STRING",)
+
+    def execute(self, value: str | list[str] = "") -> tuple[list[str]]:
+        pattern_explicit = r"<think>.*?</think>"
+        pattern_missing_start = r"^.*?</think>"
+
+        if isinstance(value, str):
+            values = [value]
+        else:
+            values = value
+
+        result = []
+        for value in values:
+            if "<think>" in value:
+                cleaned_text = re.sub(pattern_explicit, "", value, flags=re.DOTALL)
+            elif "</think>" in value:
+                cleaned_text = re.sub(pattern_missing_start, "", value, flags=re.DOTALL)
+            else:
+                cleaned_text = value
+            result.append(cleaned_text.strip())
+
+        return result,
+
+
 export_custom_nodes()
 export_package_as_web_directory("comfy_extras.language_web")
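The two patterns in `OmitThink` cover reasoning models that emit a complete `<think>...</think>` block as well as outputs where the opening tag was already stripped (for example by a streaming frontend). Behavior on both cases:

```python
import re

pattern_explicit = r"<think>.*?</think>"
pattern_missing_start = r"^.*?</think>"

full = "<think>2 + 2... carry nothing...</think>The answer is 4."
headless = "2 + 2... carry nothing...</think>The answer is 4."

print(re.sub(pattern_explicit, "", full, flags=re.DOTALL).strip())           # The answer is 4.
print(re.sub(pattern_missing_start, "", headless, flags=re.DOTALL).strip())  # The answer is 4.
```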
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import dataclasses
-import io
 import json
 import logging
 import os
@@ -16,6 +15,7 @@ from typing import Sequence, Optional, TypedDict, List, Literal, Tuple, Any, Dict
 
 import PIL
 import aiohttp
+import av
 import certifi
 import cv2
 import fsspec
@@ -872,4 +872,162 @@ class ImageRequestParameter(CustomNode):
         return ImageMaskTuple(output_images_batched, output_masks_batched)
+
+
+class LoadImageFromURL(ImageRequestParameter):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"default": ""})
+            },
+            "optional": {
+                "default_if_empty": ("IMAGE",),
+                "alpha_is_transparency": ("BOOLEAN", {"default": False}),
+            }
+        }
+
+    def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple:
+        return super().execute(value, default_if_empty, alpha_is_transparency, *args, **kwargs)
+
+
+class VideoRequestParameter(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"default": ""})
+            },
+            "optional": {
+                **_open_api_common_schema,
+                "default_if_empty": ("VIDEO",),
+                "frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1, "tooltip": "0 for no limit, otherwise stop loading after N frames"}),
+                "skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
+                "select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
+    RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")
+    FUNCTION = "execute"
+    CATEGORY = "api/openapi"
+
+    def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs) -> tuple[Tensor, Tensor, int, float]:
+        if value.strip() == "":
+            if default_if_empty is None:
+                return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)
+
+            frames = default_if_empty.shape[0] if isinstance(default_if_empty, torch.Tensor) else 0
+            height = default_if_empty.shape[1] if frames > 0 else 1
+            width = default_if_empty.shape[2] if frames > 0 else 1
+
+            default_mask = torch.ones((frames, height, width), dtype=torch.float32)
+            return (default_if_empty, default_mask, frames, 0.0)
+
+        output_videos = []
+        output_masks = []
+        total_frames_loaded = 0
+        fps = 0.0
+
+        fsspec_kwargs = {}
+        if value.startswith('http'):
+            fsspec_kwargs.update({
+                "headers": {
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36'
+                },
+                'get_client': get_client
+            })
+
+        with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files:
+            for f in files:
+                try:
+                    container = av.open(f)
+                except Exception as e:
+                    logger.error(f"VideoRequestParameter: Failed to open video container for {value}: {e}")
+                    continue
+
+                if len(container.streams.video) == 0:
+                    continue
+
+                stream = container.streams.video[0]
+                stream.thread_type = "AUTO"
+
+                if fps == 0.0:
+                    fps = float(stream.average_rate)
+
+                frames_list = []
+                masks_list = []
+                frames_processed = 0
+                frames_kept = 0
+
+                for frame in container.decode(stream):
+                    frames_processed += 1
+
+                    if frames_processed <= skip_first_frames:
+                        continue
+
+                    if (frames_processed - skip_first_frames - 1) % select_every_nth != 0:
+                        continue
+
+                    np_frame = frame.to_ndarray(format="rgba")
+                    tensor_img = torch.from_numpy(np_frame[..., :3]).float() / 255.0
+                    frames_list.append(tensor_img)
+                    tensor_mask = torch.from_numpy(np_frame[..., 3]).float() / 255.0
+                    masks_list.append(tensor_mask)
+
+                    frames_kept += 1
+
+                    if frame_load_cap > 0 and frames_kept >= frame_load_cap:
+                        break
+
+                container.close()
+
+                if frames_list:
+                    video_tensor = torch.stack(frames_list)
+                    mask_tensor = torch.stack(masks_list)
+
+                    output_videos.append(video_tensor)
+                    output_masks.append(mask_tensor)
+                    total_frames_loaded += frames_kept
+
+        if not output_videos:
+            if default_if_empty is not None:
+                frames = default_if_empty.shape[0]
+                height = default_if_empty.shape[1]
+                width = default_if_empty.shape[2]
+                return (default_if_empty, torch.ones((frames, height, width), dtype=torch.float32), frames, 0.0)
+            return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)
+
+        try:
+            final_video = torch.cat(output_videos, dim=0)
+            final_mask = torch.cat(output_masks, dim=0)
+        except RuntimeError:
+            logger.warning("VideoRequestParameter: Video resolutions mismatch in input list. Returning only the first video.")
+            final_video = output_videos[0]
+            final_mask = output_masks[0]
+            total_frames_loaded = final_video.shape[0]
+
+        return (final_video, final_mask, total_frames_loaded, fps)
+
+
+class LoadVideoFromURL(VideoRequestParameter):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"default": ""})
+            },
+            "optional": {
+                "default_if_empty": ("VIDEO",),
+                "frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1}),
+                "skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
+                "select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
+    RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")
+
+    def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs):
+        return super().execute(value, default_if_empty, frame_load_cap, skip_first_frames, select_every_nth, *args, **kwargs)
+
+
 export_custom_nodes()
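The three sampling knobs in the loader compose in a fixed order: `skip_first_frames` drops the head of the decode stream, `select_every_nth` subsamples what remains, and `frame_load_cap` stops decoding once enough frames are kept. A worked example of the same loop arithmetic on a 12-frame clip:

```python
# skip_first_frames=2, select_every_nth=3, frame_load_cap=2
kept = []
frames_processed = 0
frames_kept = 0
for frame_index in range(12):  # stands in for `for frame in container.decode(stream)`
    frames_processed += 1
    if frames_processed <= 2:                 # skip_first_frames
        continue
    if (frames_processed - 2 - 1) % 3 != 0:   # select_every_nth
        continue
    kept.append(frame_index)
    frames_kept += 1
    if frames_kept >= 2:                      # frame_load_cap
        break

print(kept)  # [2, 5]: the 3rd and 6th decoded frames, 0-indexed
```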
@@ -46,6 +46,8 @@ def write_atomic(
 
 
 torch._inductor.codecache.write_atomic = write_atomic
+
+
 # torch._inductor.utils.is_big_gpu = lambda *args: True
 
 
@@ -98,7 +100,18 @@ class TorchCompileModel(CustomNode):
                 "make_refittable": True,
             }
             del compile_kwargs["mode"]
-        if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
+        if isinstance(model, TransformersManagedModel):
+            to_return = model.clone()
+            model = to_return.model
+
+            model_management.unload_all_models()
+            model.to(device=model_management.get_torch_device())
+            res = torch.compile(model=model, **compile_kwargs)
+            model.to(device=model_management.unet_offload_device())
+
+            to_return.add_object_patch("model", res)
+            return to_return,
+        elif isinstance(model, (ModelPatcher, VAE)):
             to_return = model.clone()
             object_patches = [p.strip() for p in object_patch.split(",")]
             patcher: ModelPatcher
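For transformers models the node now compiles the wrapped module directly instead of routing through the `ModelPatcher` object-patch path used for diffusion models: unload everything, move the module to the compute device so inductor specializes kernels for it, compile, offload again, and attach the compiled module to a clone as an object patch. A condensed sketch of that flow using the same ComfyUI helpers:

```python
def compile_transformers_model(managed_model, **compile_kwargs):
    to_return = managed_model.clone()       # never mutate the cached model in place
    module = to_return.model

    model_management.unload_all_models()    # free VRAM before compilation warms up
    module.to(device=model_management.get_torch_device())
    compiled = torch.compile(model=module, **compile_kwargs)
    module.to(device=model_management.unet_offload_device())  # hand VRAM back until use

    to_return.add_object_patch("model", compiled)
    return to_return
```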