diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index 7a4be89a8..a98b900b0 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -1161,8 +1161,12 @@ class PromptServer(ExecutorToClientProgress):
             await self.send_image(data, sid=sid)
         elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
             # data is (preview_image, metadata)
+            data: PreviewImageWithMetadataMessage
             preview_image, metadata = data
+            if isinstance(preview_image, dict):
+                # todo: this has to be fixed from transformers loader for previewing tokens in real time
+                return
             await self.send_image_with_metadata(preview_image, metadata, sid=sid)
         elif isinstance(data, (bytes, bytearray)):
             await self.send_bytes(event, data, sid)
diff --git a/comfy/language/language_types.py b/comfy/language/language_types.py
index 54b6a8484..a7f93a218 100644
--- a/comfy/language/language_types.py
+++ b/comfy/language/language_types.py
@@ -92,14 +92,13 @@ class LanguageModel(Protocol):
     def generate(self, tokens: TOKENS_TYPE = None,
                  max_new_tokens: int = 512,
-                 repetition_penalty: float = 0.0,
                  seed: int = 0,
                  sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                  *args,
                  **kwargs) -> str:
         ...
 
-    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
+    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, videos: list[torch.Tensor] | None, chat_template: str | None = None) -> ProcessorResult:
         ...
 
     @property
diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py
index 5cbd3b717..31f465b43 100644
--- a/comfy/language/transformers_model_management.py
+++ b/comfy/language/transformers_model_management.py
@@ -19,7 +19,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin
 from huggingface_hub import hf_api
 from huggingface_hub.file_download import hf_hub_download
 from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
-    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, AutoModelForImageTextToText
 
 from .chat_templates import KNOWN_CHAT_TEMPLATES
 from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \
@@ -122,7 +122,7 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         for i, kwargs_to_try in enumerate(kwargses_to_try):
             try:
                 if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
-                    model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
+                    model = AutoModelForImageTextToText.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                     model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
@@ -176,7 +176,6 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
     def generate(self, tokens: TOKENS_TYPE = None,
                  max_new_tokens: int = 512,
-                 repetition_penalty: float = 0.0,
                  seed: int = 0,
                  sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                  *args,
@@ -257,7 +256,6 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
             **inputs,
             streamer=text_streamer if num_beams <= 1 else None,
             max_new_tokens=max_new_tokens,
-            repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
             **generate_kwargs
         )
@@ -364,7 +362,7 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"):
             processor.image_processor.do_rescale = False
 
-    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
+    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, videos: list[torch.Tensor] | None = None, chat_template: str | None = None) -> ProcessorResult:
         tokenizer = self.processor if self.processor is not None else self.tokenizer
         assert tokenizer is not None
         assert hasattr(tokenizer, "decode")
@@ -391,16 +389,18 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
             messages: LanguagePrompt
             if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
                 messages = prompt
-            elif "content[" in chat_template:
+            elif (images is not None and len(images) > 0) or (videos is not None and len(videos) > 0):
                 messages = [
                     {"role": "user",
                      "content": [
                          {
                              "type": "text",
-                             "text": prompt
+                             "text": prompt if isinstance(prompt, str) else ""
                          }
                      ] + [
                          {"type": "image"} for _ in range(len(images))
+                     ] + [
+                         {"type": "video"} for _ in range(len(videos) if videos is not None else 0)
                      ]
                      }
@@ -409,6 +409,7 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
                 messages = [
                     {"role": "user", "content": prompt},
                 ]
+
             prompt = tokenizer.apply_chat_template(messages, chat_template=chat_template, add_generation_prompt=True, tokenize=False)
         except Exception as exc:
             logger.debug("Could not apply chat template", exc_info=exc)
@@ -422,7 +423,14 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         # convert tuple to list from images.unbind() for paligemma workaround
         image_tensor_list = list(images.unbind()) if images is not None and len(images) > 0 else None
         try:
-            batch_feature: BatchFeature = self.processor(text=[prompt], images=image_tensor_list, return_tensors="pt", padding=True)
+            batch_feature: BatchFeature = self.processor(
+                text=[prompt],
+                images=image_tensor_list,
+                videos=None if (videos is not None and len(videos) == 0) or (hasattr(videos, "shape") and videos.shape[0] == 0) else videos,
+                return_tensors="pt",
+                padding=True,
+                input_data_format="channels_last"  # Ensure this is set for Qwen
+            )
         except TypeError as exc_info:
             logger.warning(f"Exception while trying to run processor. Your transformers package is version {transformers.__version__} and may need to be updated")
             raise exc_info
diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py
index 4c73667fc..083578a23 100644
--- a/comfy_execution/progress.py
+++ b/comfy_execution/progress.py
@@ -65,7 +65,7 @@ class ProgressHandler(ABC):
         max_value: float,
         state: NodeProgressState,
         prompt_id: str,
-        image: PreviewImageTuple | None = None,
+        image: PreviewImageTuple | dict | None = None,
     ):
         """Called when a node's progress is updated"""
         pass
@@ -210,7 +210,7 @@ class WebUIProgressHandler(ProgressHandler):
         max_value: float,
         state: NodeProgressState,
         prompt_id: str,
-        image: PreviewImageTuple | None = None,
+        image: PreviewImageTuple | dict | None = None,
     ):
         # Send progress state of all nodes
         if self.registry:
@@ -233,6 +233,7 @@ class WebUIProgressHandler(ProgressHandler):
                 ),
                 "real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
             }
+            # todo: image can be a dict because of transformers loader, do we have to deal with this specially? things just work
             message: PreviewImageWithMetadataMessage = (image, metadata)
             self.server_instance.send_sync(
                 BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
@@ -300,7 +301,7 @@ class ProgressRegistry(AbstractProgressRegistry):
             handler.start_handler(node_id, entry, self.prompt_id)
 
     def update_progress(
-        self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
+        self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | dict | None = None
     ) -> None:
         """Update progress for a node"""
         entry = self.ensure_entry(node_id)
diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py
index 349428461..ae3c7be04 100644
--- a/comfy_extras/nodes/nodes_language.py
+++ b/comfy_extras/nodes/nodes_language.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import operator
 import os.path
+import re
 from abc import ABC, abstractmethod
 from functools import reduce
 from typing import Optional, List
@@ -332,6 +333,7 @@ class OneShotInstructTokenize(CustomNode):
             },
             "optional": {
                 "images": ("IMAGE", {}),
+                "videos": ("VIDEO", {}),
                 "system_prompt": ("STRING", {"multiline": True, "default": ""})
             }
         }
@@ -340,9 +342,8 @@
     RETURN_TYPES = (TOKENS_TYPE_NAME,)
     FUNCTION = "execute"
 
-    def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: Optional[str] = _AUTO_CHAT_TEMPLATE, system_prompt: str = "") -> ValidatedNodeResult:
+    def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, videos: list | object = None, chat_template: Optional[str] = _AUTO_CHAT_TEMPLATE, system_prompt: str = "") -> ValidatedNodeResult:
         if chat_template == _AUTO_CHAT_TEMPLATE:
-            # use an exact match
             model_name = os.path.basename(model.repo_id)
             if model_name in KNOWN_CHAT_TEMPLATES:
                 chat_template = KNOWN_CHAT_TEMPLATES[model_name]
@@ -351,22 +352,43 @@
         elif chat_template is not None:
             chat_template = KNOWN_CHAT_TEMPLATES[chat_template]
 
+        video_tensors = []
+        if videos is not None:
+            if not isinstance(videos, list):
+                videos_list = [videos]
+            else:
+                videos_list = videos
+
+            for vid in videos_list:
+                if hasattr(vid, "get_components"):
+                    components = vid.get_components()
+                    video_tensors.append(components.images)
+                elif isinstance(vid, torch.Tensor):
+                    video_tensors.append(vid)
+
         messages: LanguagePrompt | str
-        if system_prompt != "":
-            messages: LanguagePrompt = [
-                {"role": "system",
-                 "content": system_prompt},
-                {"role": "user",
-                 "content": [
-                     {"type": "text",
-                      "text": prompt}
-                 ] + [
-                     {"type": "image"} for _ in range(len(images) if images is not None else 0)
-                 ], }
+
+        has_images = images is not None and len(images) > 0
+        has_videos = len(video_tensors) > 0
+
+        if system_prompt != "" or has_images or has_videos:
+            user_content = [{"type": "text", "text": prompt}]
+            if has_images:
+                user_content += [{"type": "image"} for _ in range(len(images))]
+
+            if has_videos:
+                user_content += [{"type": "video"} for _ in range(len(video_tensors))]
+
+            messages = [
+                {"role": "user", "content": user_content}
             ]
+
+            if system_prompt.strip() != "":
+                messages.insert(0, {"role": "system", "content": system_prompt})
         else:
-            messages: str = prompt
-            return model.tokenize(messages, images, chat_template),
+            messages = prompt
+
+        return model.tokenize(messages, images, video_tensors, chat_template),
 
 
 class TransformersGenerate(CustomNode):
@@ -376,8 +398,7 @@ class TransformersGenerate(CustomNode):
             "required": {
                 "model": ("MODEL",),
                 "tokens": (TOKENS_TYPE_NAME, {}),
-                "max_new_tokens": ("INT", {"default": 512, "min": 1}),
-                "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
+                "max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 0xffffffff}),
                 "seed": Seed,
             },
             "optional": {
@@ -393,11 +414,10 @@ class TransformersGenerate(CustomNode):
                 model: Optional[LanguageModel] = None,
                 tokens: TOKENS_TYPE = None,
                 max_new_tokens: int = 512,
-                repetition_penalty: float = 0.0,
                 seed: int = 0,
                 sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                 ):
-        return model.generate(tokens, max_new_tokens, repetition_penalty, seed, sampler),
+        return model.generate(tokens, max_new_tokens, seed, sampler),
 
 
 class PreviewString(CustomNode):
@@ -427,7 +447,7 @@ class SaveString(CustomNode):
                 "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
             },
             "optional": {
-                "extension": ("STRING", {"default": ".json"})
+                "extension": ("STRING", {"default": ".txt"})
             }
         }
@@ -439,17 +459,53 @@ class SaveString(CustomNode):
     def get_save_path(self, filename_prefix) -> SaveImagePathTuple:
         return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)
 
-    def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):
+    def execute(self, value: str | list[str] = "", filename_prefix: str = "ComfyUI", extension: str = ".txt"):
         full_output_folder, filename, counter, subfolder, filename_prefix = self.get_save_path(filename_prefix)
         if isinstance(value, str):
             value = [value]
         for i, value_i in enumerate(value):
             # roughly matches the behavior of save image, but does not support batch numbers
-            with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}_{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}_{extension}"), "wt+") as f:
+            with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}{extension}"), "wt+") as f:
                 f.write(value_i)
         return {"ui": {"string": value}}
 
 
+class OmitThink(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"forceInput": True}),
+            },
+        }
+
+    CATEGORY = "strings"
+    FUNCTION = "execute"
+    OUTPUT_NODE = True
+    RETURN_TYPES = ("STRING",)
+
+    def execute(self, value: str | list[str] = "") -> tuple[list[str]]:
+        pattern_explicit = r"<think>.*?</think>"
+        pattern_missing_start = r"^.*?</think>"
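+        # the first pattern drops well-formed <think>...</think> blocks; the second strips a leading
+        # fragment where only the closing </think> tag survived (e.g. truncated streamed output)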
+
+        if isinstance(value, str):
+            values = [value]
+        else:
+            values = value
+
+        result = []
+        for value in values:
+            if "<think>" in value:
+                cleaned_text = re.sub(pattern_explicit, "", value, flags=re.DOTALL)
+            elif "</think>" in value:
+                cleaned_text = re.sub(pattern_missing_start, "", value, flags=re.DOTALL)
+            else:
+                cleaned_text = value
+            result.append(cleaned_text.strip())
+
+        return result,
+
+
 export_custom_nodes()
 export_package_as_web_directory("comfy_extras.language_web")
diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py
index 2ecbe56df..839af3602 100644
--- a/comfy_extras/nodes/nodes_open_api.py
+++ b/comfy_extras/nodes/nodes_open_api.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import dataclasses
-import io
 import json
 import logging
 import os
@@ -16,6 +15,7 @@ from typing import Sequence, Optional, TypedDict, List, Literal, Tuple, Any, Dict
 
 import PIL
 import aiohttp
+import av
 import certifi
 import cv2
 import fsspec
@@ -872,4 +872,162 @@ class ImageRequestParameter(CustomNode):
         return ImageMaskTuple(output_images_batched, output_masks_batched)
 
 
+class LoadImageFromURL(ImageRequestParameter):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"default": ""})
+            },
+            "optional": {
+                "default_if_empty": ("IMAGE",),
+                "alpha_is_transparency": ("BOOLEAN", {"default": False}),
+            }
+        }
+
+    def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple:
+        return super().execute(value, default_if_empty, alpha_is_transparency, *args, **kwargs)
+
+
+class VideoRequestParameter(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"default": ""})
+            },
+            "optional": {
+                **_open_api_common_schema,
+                "default_if_empty": ("VIDEO",),
+                "frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1, "tooltip": "0 for no limit, otherwise stop loading after N frames"}),
+                "skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
+                "select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
+    RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")
+    FUNCTION = "execute"
+    CATEGORY = "api/openapi"
+
+    def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs) -> tuple[Tensor, Tensor, int, float]:
+        if value.strip() == "":
+            if default_if_empty is None:
+                return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)
+
+            frames = default_if_empty.shape[0] if isinstance(default_if_empty, torch.Tensor) else 0
+            height = default_if_empty.shape[1] if frames > 0 else 1
+            width = default_if_empty.shape[2] if frames > 0 else 1
+
+            default_mask = torch.ones((frames, height, width), dtype=torch.float32)
+            return (default_if_empty, default_mask, frames, 0.0)
+
+        output_videos = []
+        output_masks = []
+        total_frames_loaded = 0
+        fps = 0.0
+
+        fsspec_kwargs = {}
+        if value.startswith('http'):
+            fsspec_kwargs.update({
+                "headers": {
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5672.64 Safari/537.36'
+                },
+                'get_client': get_client
+            })
+
+        with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files:
+            for f in files:
+                try:
+                    container = av.open(f)
+                except Exception as e:
+                    logger.error(f"VideoRequestParameter: Failed to open video container for {value}: {e}")
+                    continue
+
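+                # skip sources that opened but expose no video stream (e.g. audio-only files)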
+                if len(container.streams.video) == 0:
+                    continue
+
+                stream = container.streams.video[0]
+                stream.thread_type = "AUTO"
+
+                if fps == 0.0:
+                    fps = float(stream.average_rate)
+
+                frames_list = []
+                masks_list = []
+                frames_processed = 0
+                frames_kept = 0
+
+                for frame in container.decode(stream):
+                    frames_processed += 1
+
+                    if frames_processed <= skip_first_frames:
+                        continue
+
+                    if (frames_processed - skip_first_frames - 1) % select_every_nth != 0:
+                        continue
+
+                    np_frame = frame.to_ndarray(format="rgba")
+                    tensor_img = torch.from_numpy(np_frame[..., :3]).float() / 255.0
+                    frames_list.append(tensor_img)
+                    tensor_mask = torch.from_numpy(np_frame[..., 3]).float() / 255.0
+                    masks_list.append(tensor_mask)
+
+                    frames_kept += 1
+
+                    if frame_load_cap > 0 and frames_kept >= frame_load_cap:
+                        break
+
+                container.close()
+
+                if frames_list:
+                    video_tensor = torch.stack(frames_list)
+                    mask_tensor = torch.stack(masks_list)
+
+                    output_videos.append(video_tensor)
+                    output_masks.append(mask_tensor)
+                    total_frames_loaded += frames_kept
+
+        if not output_videos:
+            if default_if_empty is not None:
+                frames = default_if_empty.shape[0]
+                height = default_if_empty.shape[1]
+                width = default_if_empty.shape[2]
+                return (default_if_empty, torch.ones((frames, height, width), dtype=torch.float32), frames, 0.0)
+            return (torch.zeros((0, 1, 1, 3)), torch.zeros((0, 1, 1)), 0, 0.0)
+
+        try:
+            final_video = torch.cat(output_videos, dim=0)
+            final_mask = torch.cat(output_masks, dim=0)
+        except RuntimeError:
+            logger.warning("VideoRequestParameter: Video resolutions mismatch in input list. Returning only the first video.")
+            final_video = output_videos[0]
+            final_mask = output_masks[0]
+            total_frames_loaded = final_video.shape[0]
+
+        return (final_video, final_mask, total_frames_loaded, fps)
+
+
+class LoadVideoFromURL(VideoRequestParameter):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "value": ("STRING", {"default": ""})
+            },
+            "optional": {
+                "default_if_empty": ("VIDEO",),
+                "frame_load_cap": ("INT", {"default": 0, "min": 0, "step": 1}),
+                "skip_first_frames": ("INT", {"default": 0, "min": 0, "step": 1}),
+                "select_every_nth": ("INT", {"default": 1, "min": 1, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("VIDEO", "MASK", "INT", "FLOAT")
+    RETURN_NAMES = ("VIDEO", "MASK", "frame_count", "fps")
+
+    def execute(self, value: str = "", default_if_empty=None, frame_load_cap=0, skip_first_frames=0, select_every_nth=1, *args, **kwargs):
+        return super().execute(value, default_if_empty, frame_load_cap, skip_first_frames, select_every_nth, *args, **kwargs)
+
+
 export_custom_nodes()
diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py
index 2f4403e98..6a4a3f476 100644
--- a/comfy_extras/nodes/nodes_torch_compile.py
+++ b/comfy_extras/nodes/nodes_torch_compile.py
@@ -46,6 +46,8 @@ def write_atomic(
 
 torch._inductor.codecache.write_atomic = write_atomic
+
+
 
 # torch._inductor.utils.is_big_gpu = lambda *args: True
@@ -98,7 +100,18 @@ class TorchCompileModel(CustomNode):
                 "make_refittable": True,
             }
             del compile_kwargs["mode"]
-        if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
+        if isinstance(model, TransformersManagedModel):
+            to_return = model.clone()
+            model = to_return.model
+
+            model_management.unload_all_models()
+            model.to(device=model_management.get_torch_device())
+            res = torch.compile(model=model, **compile_kwargs)
+            model.to(device=model_management.unet_offload_device())
+
+            to_return.add_object_patch("model", res)
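+            # the compiled module rides along as an object patch on the clone, leaving the original model object untouched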
+            return to_return,
+        elif isinstance(model, (ModelPatcher, VAE)):
             to_return = model.clone()
             object_patches = [p.strip() for p in object_patch.split(",")]
             patcher: ModelPatcher