add video support for language models, improve torch compile for transformers models, make it easier to load videos, fix progress messages from language nodes

doctorpangloss 2025-12-03 13:45:01 -08:00
parent 81ea9726b6
commit b149031748
7 changed files with 276 additions and 37 deletions

View File

@@ -1161,8 +1161,12 @@ class PromptServer(ExecutorToClientProgress):
                 await self.send_image(data, sid=sid)
             elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
                 # data is (preview_image, metadata)
                 data: PreviewImageWithMetadataMessage
                 preview_image, metadata = data
+                if isinstance(preview_image, dict):
+                    # todo: this has to be fixed from transformers loader for previewing tokens in real time
+                    return
                 await self.send_image_with_metadata(preview_image, metadata, sid=sid)
         elif isinstance(data, (bytes, bytearray)):
             await self.send_bytes(event, data, sid)

View File

@@ -92,14 +92,13 @@ class LanguageModel(Protocol):
     def generate(self, tokens: TOKENS_TYPE = None,
                  max_new_tokens: int = 512,
-                 repetition_penalty: float = 0.0,
                  seed: int = 0,
                  sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                  *args,
                  **kwargs) -> str:
         ...

-    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
+    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, videos: list[torch.Tensor] | None, chat_template: str | None = None) -> ProcessorResult:
         ...

     @property
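For orientation, a hedged sketch of driving the revised protocol after this change. `model` stands in for any LanguageModel implementation and the tensors are placeholders produced elsewhere in a workflow; none of these names are defined by the commit itself.

# Hypothetical usage sketch of the revised LanguageModel protocol.
import torch

images: torch.Tensor | None = None              # RGBImageBatch, BHWC floats in [0, 1]
video_frames: list[torch.Tensor] | None = None  # one (frames, H, W, C) tensor per video

tokens = model.tokenize("Describe the clip.", images, video_frames, chat_template=None)
# repetition_penalty left the signature in this commit; if still needed it presumably
# travels through the sampler kwargs instead.
text = model.generate(tokens, max_new_tokens=256, seed=42, sampler=None)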

View File

@ -19,7 +19,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixi
from huggingface_hub import hf_api from huggingface_hub import hf_api
from huggingface_hub.file_download import hf_hub_download from huggingface_hub.file_download import hf_hub_download
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \ from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, AutoModelForImageTextToText
from .chat_templates import KNOWN_CHAT_TEMPLATES from .chat_templates import KNOWN_CHAT_TEMPLATES
from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \ from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \
@@ -122,7 +122,7 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         for i, kwargs_to_try in enumerate(kwargses_to_try):
             try:
                 if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
-                    model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
+                    model = AutoModelForImageTextToText.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                     model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
@@ -176,7 +176,6 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
     def generate(self, tokens: TOKENS_TYPE = None,
                  max_new_tokens: int = 512,
-                 repetition_penalty: float = 0.0,
                  seed: int = 0,
                  sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                  *args,
@ -257,7 +256,6 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
**inputs, **inputs,
streamer=text_streamer if num_beams <= 1 else None, streamer=text_streamer if num_beams <= 1 else None,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
**generate_kwargs **generate_kwargs
) )
@@ -364,7 +362,7 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"):
             processor.image_processor.do_rescale = False

-    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
+    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, videos: list[torch.Tensor] | None = None, chat_template: str | None = None) -> ProcessorResult:
         tokenizer = self.processor if self.processor is not None else self.tokenizer
         assert tokenizer is not None
         assert hasattr(tokenizer, "decode")
@@ -391,16 +389,18 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
                 messages: LanguagePrompt
                 if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
                     messages = prompt
-                elif "content[" in chat_template:
+                elif images is not None and len(images) > 0 or videos is not None and len(videos) > 0:
                     messages = [
                         {"role": "user",
                          "content": [
                              {
                                  "type": "text",
-                                 "text": prompt
+                                 "text": prompt if isinstance(prompt, str) else ""
                              }
                          ] + [
                              {"type": "image"} for _ in range(len(images))
+                         ] + [
+                             {"type": "video"} for _ in range(len(videos))
                          ]
                          }
@@ -409,6 +409,7 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
                     messages = [
                         {"role": "user", "content": prompt},
                     ]
                 prompt = tokenizer.apply_chat_template(messages, chat_template=chat_template, add_generation_prompt=True, tokenize=False)
             except Exception as exc:
                 logger.debug("Could not apply chat template", exc_info=exc)
@@ -422,7 +423,14 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
             # convert tuple to list from images.unbind() for paligemma workaround
             image_tensor_list = list(images.unbind()) if images is not None and len(images) > 0 else None
             try:
-                batch_feature: BatchFeature = self.processor(text=[prompt], images=image_tensor_list, return_tensors="pt", padding=True)
+                batch_feature: BatchFeature = self.processor(
+                    text=[prompt],
+                    images=image_tensor_list,
+                    videos=None if videos is not None and len(videos) == 0 or (hasattr(videos, "shape") and videos.shape[0]) == 0 else videos,
+                    return_tensors="pt",
+                    padding=True,
+                    input_data_format="channels_last"  # Ensure this is set for Qwen
+                )
             except TypeError as exc_info:
                 logger.warning(f"Exception while trying to run processor. Your transformers package is version {transformers.__version__} and may need to be updated")
                 raise exc_info
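For context on the loader dispatch touched above: vision-to-sequence architectures now load through AutoModelForImageTextToText rather than AutoModelForVision2Seq. A minimal standalone sketch of that check, assuming a recent transformers release that exposes AutoModelForImageTextToText; the repo id is a placeholder, not something named by the commit.

# Illustrative sketch only; repo_id is a placeholder.
from transformers import AutoConfig, AutoModelForImageTextToText, AutoModelForCausalLM
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES

repo_id = "some/vision-language-model"  # placeholder
config = AutoConfig.from_pretrained(repo_id)
if config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
    # image/video-capable checkpoints take this branch after the change above
    model = AutoModelForImageTextToText.from_pretrained(repo_id)
else:
    model = AutoModelForCausalLM.from_pretrained(repo_id)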

View File

@ -65,7 +65,7 @@ class ProgressHandler(ABC):
max_value: float, max_value: float,
state: NodeProgressState, state: NodeProgressState,
prompt_id: str, prompt_id: str,
image: PreviewImageTuple | None = None, image: PreviewImageTuple | dict | None = None,
): ):
"""Called when a node's progress is updated""" """Called when a node's progress is updated"""
pass pass
@@ -210,7 +210,7 @@ class WebUIProgressHandler(ProgressHandler):
         max_value: float,
         state: NodeProgressState,
         prompt_id: str,
-        image: PreviewImageTuple | None = None,
+        image: PreviewImageTuple | dict | None = None,
     ):
         # Send progress state of all nodes
         if self.registry:
@@ -233,6 +233,7 @@ class WebUIProgressHandler(ProgressHandler):
                 ),
                 "real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
             }
+            # todo: image can be a dict because of transformers loader, do we have to deal with this specially? things just work
             message: PreviewImageWithMetadataMessage = (image, metadata)
             self.server_instance.send_sync(
                 BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
@ -300,7 +301,7 @@ class ProgressRegistry(AbstractProgressRegistry):
handler.start_handler(node_id, entry, self.prompt_id) handler.start_handler(node_id, entry, self.prompt_id)
def update_progress( def update_progress(
self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | dict | None = None
) -> None: ) -> None:
"""Update progress for a node""" """Update progress for a node"""
entry = self.ensure_entry(node_id) entry = self.ensure_entry(node_id)
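A rough sketch of what widening `image` to accept a dict enables. The dict keys below are assumptions: the commit only notes that the transformers loader streams token previews this way and that the server currently drops dict payloads before send_image_with_metadata.

# Hypothetical call site; the payload keys are assumptions, not defined by this commit.
def on_token_streamed(registry, node_id: str, generated_so_far: str, step: int, max_steps: int) -> None:
    # A language-model streamer can now pass a dict instead of a PreviewImageTuple.
    registry.update_progress(node_id, value=step, max_value=max_steps,
                             image={"text": generated_so_far})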

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import operator import operator
import os.path import os.path
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import reduce from functools import reduce
from typing import Optional, List from typing import Optional, List
@@ -332,6 +333,7 @@ class OneShotInstructTokenize(CustomNode):
             },
             "optional": {
                 "images": ("IMAGE", {}),
+                "videos": ("VIDEO", {}),
                 "system_prompt": ("STRING", {"multiline": True, "default": ""})
             }
         }
@@ -340,9 +342,8 @@ class OneShotInstructTokenize(CustomNode):
     RETURN_TYPES = (TOKENS_TYPE_NAME,)
     FUNCTION = "execute"

-    def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: Optional[str] = _AUTO_CHAT_TEMPLATE, system_prompt: str = "") -> ValidatedNodeResult:
+    def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, videos: list | object = None, chat_template: Optional[str] = _AUTO_CHAT_TEMPLATE, system_prompt: str = "") -> ValidatedNodeResult:
         if chat_template == _AUTO_CHAT_TEMPLATE:
-            # use an exact match
             model_name = os.path.basename(model.repo_id)
             if model_name in KNOWN_CHAT_TEMPLATES:
                 chat_template = KNOWN_CHAT_TEMPLATES[model_name]
@@ -351,22 +352,43 @@ class OneShotInstructTokenize(CustomNode):
         elif chat_template is not None:
             chat_template = KNOWN_CHAT_TEMPLATES[chat_template]

+        video_tensors = []
+        if videos is not None:
+            if not isinstance(videos, list):
+                videos_list = [videos]
+            else:
+                videos_list = videos
+            for vid in videos_list:
+                if hasattr(vid, "get_components"):
+                    components = vid.get_components()
+                    video_tensors.append(components.images)
+                elif isinstance(vid, torch.Tensor):
+                    video_tensors.append(vid)
+
         messages: LanguagePrompt | str
-        if system_prompt != "":
-            messages: LanguagePrompt = [
-                {"role": "system",
-                 "content": system_prompt},
-                {"role": "user",
-                 "content": [
-                     {"type": "text",
-                      "text": prompt}
-                 ] + [
-                     {"type": "image"} for _ in range(len(images) if images is not None else 0)
-                 ], }
+
+        has_images = images is not None and len(images) > 0
+        has_videos = len(video_tensors) > 0
+
+        if system_prompt != "" or has_images or has_videos:
+            user_content = [{"type": "text", "text": prompt}]
+            if has_images:
+                user_content += [{"type": "image"} for _ in range(len(images))]
+            if has_videos:
+                user_content += [{"type": "video"} for _ in range(len(video_tensors))]
+            messages = [
+                {"role": "user", "content": user_content}
             ]
+            if system_prompt.strip() != "":
+                messages.insert(0, {"role": "system", "content": system_prompt})
         else:
-            messages: str = prompt
-        return model.tokenize(messages, images, chat_template),
+            messages = prompt
+
+        return model.tokenize(messages, images, video_tensors, chat_template),


 class TransformersGenerate(CustomNode):
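A short sketch of how the new videos input is likely meant to be exercised from Python, assuming a VIDEO object that exposes get_components() the way the node probes for it. `managed_model` and `video_obj` are placeholders for outputs of the corresponding loader nodes, not APIs defined in this commit.

# Hypothetical direct call; in a workflow the graph executor drives this node.
node = OneShotInstructTokenize()
tokens, = node.execute(
    model=managed_model,
    prompt="Summarize what happens in the clip.",
    images=None,
    videos=video_obj,          # a single VIDEO or a list; raw tensors also pass through as-is
    system_prompt="You are a concise video captioner.",
)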
@@ -376,8 +398,7 @@ class TransformersGenerate(CustomNode):
             "required": {
                 "model": ("MODEL",),
                 "tokens": (TOKENS_TYPE_NAME, {}),
-                "max_new_tokens": ("INT", {"default": 512, "min": 1}),
-                "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
+                "max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 0xffffffff}),
                 "seed": Seed,
             },
             "optional": {
@ -393,11 +414,10 @@ class TransformersGenerate(CustomNode):
model: Optional[LanguageModel] = None, model: Optional[LanguageModel] = None,
tokens: TOKENS_TYPE = None, tokens: TOKENS_TYPE = None,
max_new_tokens: int = 512, max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0, seed: int = 0,
sampler: Optional[GENERATION_KWARGS_TYPE] = None, sampler: Optional[GENERATION_KWARGS_TYPE] = None,
): ):
return model.generate(tokens, max_new_tokens, repetition_penalty, seed, sampler), return model.generate(tokens, max_new_tokens, seed, sampler),
class PreviewString(CustomNode): class PreviewString(CustomNode):
@@ -427,7 +447,7 @@ class SaveString(CustomNode):
                 "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
             },
             "optional": {
-                "extension": ("STRING", {"default": ".json"})
+                "extension": ("STRING", {"default": ".txt"})
             }
         }
@@ -439,17 +459,53 @@
     def get_save_path(self, filename_prefix) -> SaveImagePathTuple:
         return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)

-    def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):
+    def execute(self, value: str | list[str] = "", filename_prefix: str = "ComfyUI", extension: str = ".txt"):
         full_output_folder, filename, counter, subfolder, filename_prefix = self.get_save_path(filename_prefix)
         if isinstance(value, str):
             value = [value]
         for i, value_i in enumerate(value):
             # roughly matches the behavior of save image, but does not support batch numbers
-            with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}_{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}_{extension}"), "wt+") as f:
+            with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}{extension}"), "wt+") as f:
                 f.write(value_i)
         return {"ui": {"string": value}}


+class OmitThink(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"forceInput": True}),
+            },
+        }
+
+    CATEGORY = "strings"
+    FUNCTION = "execute"
+    OUTPUT_NODE = True
+    RETURN_TYPES = ("STRING",)
+
+    def execute(self, value: str | list[str] = "") -> tuple[list[str]]:
+        pattern_explicit = r"<think>.*?</think>"
+        pattern_missing_start = r"^.*?</think>"
+        if isinstance(value, str):
+            values = [value]
+        else:
+            values = value
+        result = []
+        for value in values:
+            if "<think>" in value:
+                cleaned_text = re.sub(pattern_explicit, "", value, flags=re.DOTALL)
+            elif "</think>" in value:
+                cleaned_text = re.sub(pattern_missing_start, "", value, flags=re.DOTALL)
+            else:
+                cleaned_text = value
+            result.append(cleaned_text.strip())
+        return result,
+
+
 export_custom_nodes()
 export_package_as_web_directory("comfy_extras.language_web")
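For clarity, a small worked example of what the OmitThink regexes do to reasoning-model output; the input strings are made up.

import re

sample_full = "<think>Let me reason about the frames...</think>The cat jumps onto the table."
sample_truncated = "reasoning that lost its opening tag</think>The cat jumps onto the table."

print(re.sub(r"<think>.*?</think>", "", sample_full, flags=re.DOTALL).strip())
# -> "The cat jumps onto the table."
print(re.sub(r"^.*?</think>", "", sample_truncated, flags=re.DOTALL).strip())
# -> "The cat jumps onto the table."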

View File

@@ -1,7 +1,6 @@
 from __future__ import annotations

 import dataclasses
-import io
 import json
 import logging
 import os
@@ -16,6 +15,7 @@ from typing import Sequence, Optional, TypedDict, List, Literal, Tuple, Any, Dict
 import PIL
 import aiohttp
+import av
 import certifi
 import cv2
 import fsspec
@@ -872,4 +872,162 @@ class ImageRequestParameter(CustomNode):
         return ImageMaskTuple(output_images_batched, output_masks_batched)


+class LoadImageFromURL(ImageRequestParameter):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"default": ""})
+            },
+            "optional": {
+                "default_if_empty": ("IMAGE",),
+                "alpha_is_transparency": ("BOOLEAN", {"default": False}),
+            }
+        }
+
+    def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple:
+        return super().execute(value, default_if_empty, alpha_is_transparency, *args, **kwargs)
+
+
+class VideoRequestParameter(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"default": ""})
+            },
+            "optional": {
+                **_open_api_common_schema,
+                "default_if_empty": ("VIDEO",),
+                "frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1, "tooltip": "0 for no limit, otherwise stop loading after N frames"}),
+                "skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
+                "select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
+    RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")
+    FUNCTION = "execute"
+    CATEGORY = "api/openapi"
+
+    def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs) -> tuple[Tensor, Tensor, int, float]:
+        if value.strip() == "":
+            if default_if_empty is None:
+                return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)
+            frames = default_if_empty.shape[0] if isinstance(default_if_empty, torch.Tensor) else 0
+            height = default_if_empty.shape[1] if frames > 0 else 1
+            width = default_if_empty.shape[2] if frames > 0 else 1
+            default_mask = torch.ones((frames, height, width), dtype=torch.float32)
+            return (default_if_empty, default_mask, frames, 0.0)
+
+        output_videos = []
+        output_masks = []
+        total_frames_loaded = 0
+        fps = 0.0
+
+        fsspec_kwargs = {}
+        if value.startswith('http'):
+            fsspec_kwargs.update({
+                "headers": {
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36'
+                },
+                'get_client': get_client
+            })
+
+        with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files:
+            for f in files:
+                try:
+                    container = av.open(f)
+                except Exception as e:
+                    logger.error(f"VideoRequestParameter: Failed to open video container for {value}: {e}")
+                    continue
+
+                if len(container.streams.video) == 0:
+                    continue
+
+                stream = container.streams.video[0]
+                stream.thread_type = "AUTO"
+                if fps == 0.0:
+                    fps = float(stream.average_rate)
+
+                frames_list = []
+                masks_list = []
+                frames_processed = 0
+                frames_kept = 0
+
+                for frame in container.decode(stream):
+                    frames_processed += 1
+                    if frames_processed <= skip_first_frames:
+                        continue
+                    if (frames_processed - skip_first_frames - 1) % select_every_nth != 0:
+                        continue
+
+                    np_frame = frame.to_ndarray(format="rgba")
+                    tensor_img = torch.from_numpy(np_frame[..., :3]).float() / 255.0
+                    frames_list.append(tensor_img)
+                    tensor_mask = torch.from_numpy(np_frame[..., 3]).float() / 255.0
+                    masks_list.append(tensor_mask)
+
+                    frames_kept += 1
+                    if frame_load_cap > 0 and frames_kept >= frame_load_cap:
+                        break
+
+                container.close()
+
+                if frames_list:
+                    video_tensor = torch.stack(frames_list)
+                    mask_tensor = torch.stack(masks_list)
+                    output_videos.append(video_tensor)
+                    output_masks.append(mask_tensor)
+                    total_frames_loaded += frames_kept
+
+        if not output_videos:
+            if default_if_empty is not None:
+                frames = default_if_empty.shape[0]
+                height = default_if_empty.shape[1]
+                width = default_if_empty.shape[2]
+                return (default_if_empty, torch.ones((frames, height, width), dtype=torch.float32), frames, 0.0)
+            return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)
+
+        try:
+            final_video = torch.cat(output_videos, dim=0)
+            final_mask = torch.cat(output_masks, dim=0)
+        except RuntimeError:
+            logger.warning("VideoRequestParameter: Video resolutions mismatch in input list. Returning only the first video.")
+            final_video = output_videos[0]
+            final_mask = output_masks[0]
+            total_frames_loaded = final_video.shape[0]
+
+        return (final_video, final_mask, total_frames_loaded, fps)
+
+
+class LoadVideoFromURL(VideoRequestParameter):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"default": ""})
+            },
+            "optional": {
+                "default_if_empty": ("VIDEO",),
+                "frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1}),
+                "skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
+                "select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
+    RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")
+
+    def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs):
+        return super().execute(value, default_if_empty, frame_load_cap, skip_first_frames, select_every_nth, *args, **kwargs)
+
+
 export_custom_nodes()
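A quick sketch of calling the new loader directly; the URL is a placeholder, and in a workflow the graph executor would drive the node instead.

# Placeholder URL; any http(s) or fsspec-readable path is handled the same way.
node = LoadVideoFromURL()
video, mask, frame_count, fps = node.execute(
    value="https://example.com/clip.mp4",
    frame_load_cap=64,        # stop after 64 kept frames
    skip_first_frames=0,
    select_every_nth=2,       # keep every other decoded frame
)
# video: (frames, height, width, 3) float tensor in [0, 1], suitable as a videos input upstream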

View File

@@ -46,6 +46,8 @@ def write_atomic(
 torch._inductor.codecache.write_atomic = write_atomic

 # torch._inductor.utils.is_big_gpu = lambda *args: True
@@ -98,7 +100,18 @@ class TorchCompileModel(CustomNode):
                 "make_refittable": True,
             }
             del compile_kwargs["mode"]
-        if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
+        if isinstance(model, TransformersManagedModel):
+            to_return = model.clone()
+            model = to_return.model
+            model_management.unload_all_models()
+            model.to(device=model_management.get_torch_device())
+            res = torch.compile(model=model, **compile_kwargs),
+            model.to(device=model_management.unet_offload_device())
+            to_return.add_object_patch("model", res)
+            return to_return,
+        elif isinstance(model, (ModelPatcher, VAE)):
             to_return = model.clone()
             object_patches = [p.strip() for p in object_patch.split(",")]
             patcher: ModelPatcher
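As a sanity check on the flow above, a minimal standalone sketch of compiling a Hugging Face causal LM the same basic way; the checkpoint name and the absence of ComfyUI's offload bookkeeping are assumptions for illustration, not taken from the node.

# Standalone sketch; "sshleifer/tiny-gpt2" is just a small public checkpoint for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

model = torch.compile(model)  # same call the node makes, minus device moves and unloading

inputs = tokenizer("hello", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0]))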