diff --git a/README.md b/README.md index 37493ff43..20771bd03 100644 --- a/README.md +++ b/README.md @@ -479,6 +479,28 @@ comfyui --use-sage-attention ![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp) **With PyTorch Attention** +## Cosmos Prompt Upsampling + +The Cosmos prompt "upsampler," a fine tune of Mistral-Nemo-12b, correctly rewrites Cosmos prompts in the narrative style that NVIDIA's captioner used for the training data of Cosmos, improving generation results significantly. + +Here is a comparison between a simple and "upsampled" prompt. + +![prompt_upsampling_01.webp](docs/assets/prompt_upsampling_01.webp) +**A dog is playing with a ball.** + +![prompt_upsampling_02.webp](docs/assets/prompt_upsampling_02.webp) +**In a sun-drenched park, a playful golden retriever bounds joyfully across the lush green grass, its tail wagging with excitement. The dog, adorned with a vibrant red collar, is captivated by a bright yellow ball, which it cradles gently in its mouth. The camera captures the dog's animated expressions, from eager anticipation to sheer delight, as it trots and leaps, showcasing its agility and enthusiasm. The scene is bathed in warm, golden-hour light, enhancing the vibrant colors of the dog's fur and the ball. The background features a serene tree line, framing the playful interaction and creating a tranquil atmosphere. The static camera angle allows for an intimate focus on the dog's joyful antics, inviting viewers to share in this heartwarming moment of pure canine happiness.** + +To use the Cosmos upsampler, install the prerequisites: + +```shell +uv pip install loguru pynvml +uv pip install --no-deps git+https://github.com/NVIDIA/Cosmos.git +``` +Then, use the workflow embedded in the upsampled prompt by dragging and dropping the upsampled animation into your workspace. + +The Cosmos upsampler ought to improve any text-to-image video generation pipeline. Use the `Video2World` upsampler nodes to download Pixtral-12b and upsample for an image to video workflow using NVIDIA's default prompt. Since Pixtral is not fine tuned, the improvement may not be significant over using another LLM. + # Custom Nodes Custom Nodes can be added to ComfyUI by copying and pasting Python files into your `./custom_nodes` directory. diff --git a/comfy/component_model/empty_init.py b/comfy/component_model/empty_init.py new file mode 100644 index 000000000..b74a359e8 --- /dev/null +++ b/comfy/component_model/empty_init.py @@ -0,0 +1,19 @@ +import torch +import torch.overrides +import torch.utils._device + + +class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): + def __init__(self, device=None): + self.device = device + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if getattr(func, '__module__', None) == 'torch.nn.init': + if 'tensor' in kwargs: + return kwargs['tensor'] + else: + return args[0] + if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None: + kwargs['device'] = self.device + return func(*args, **kwargs) diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py index bcb34459d..224c0716c 100644 --- a/comfy/language/transformers_model_management.py +++ b/comfy/language/transformers_model_management.py @@ -6,6 +6,7 @@ import inspect import logging import operator import pathlib +import weakref from functools import reduce from typing import Optional, Any, Callable @@ -43,9 +44,10 @@ class TransformersManagedModel(ModelManageable, LanguageModel): processor: Optional[ProcessorMixin | AutoProcessor] = None ): self._repo_id = repo_id - self.model = model + self._model = model self._tokenizer = tokenizer self._processor = processor + self._object_patches: dict[str, Any] = {} self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values()) self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values()) self.load_device = get_torch_device() @@ -53,6 +55,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel): self._config_dict = config_dict self._on_set_processor(self._processor) self._model_type = "" + self._original_transformers_managed_model: weakref.ReferenceType["TransformersManagedModel"] = weakref.ref(self) if model.device != self.offload_device: model.to(device=self.offload_device) @@ -426,6 +429,28 @@ class TransformersManagedModel(ModelManageable, LanguageModel): else: return f"" + def clone(self) -> TransformersManagedModel: + m = copy.copy(self) + # deep copy a few objects + m._object_patches = copy.copy(self._object_patches) + return m + + def add_object_patch(self, name: str, obj: Any): + # for the sake of compatibility, rewrite the name to the actual model field + if name == "diffusion_model": + name = "model" + + self._object_patches[name] = obj + + def get_model_object(self, name: str) -> torch.nn.Module: + if name == "diffusion_model": + name = "model" + return super().get_model_object(name) + + @property + def model(self) -> PreTrainedModel | torch.nn.Module: + return self._object_patches.get("model", self._model) + class _ProgressTextStreamer(TextStreamer): def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): diff --git a/comfy/model_management.py b/comfy/model_management.py index 00cdc07f5..365ac26e9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -406,7 +406,7 @@ class LoadedModel: self._set_model(model) @property - def model(self): + def model(self) -> ModelManageable: return self._model() def model_memory(self): diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py index 266f4e30d..ceda5f246 100644 --- a/comfy/model_management_types.py +++ b/comfy/model_management_types.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any +from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple import torch import torch.nn @@ -150,6 +150,8 @@ class MemoryMeasurements: def device(self) -> torch.device: if isinstance(self.model, DeviceSettable): return self.model.device + elif hasattr(self.model, "device"): + return self.model.device else: return self._device @@ -157,6 +159,8 @@ class MemoryMeasurements: def device(self, value: torch.device): if isinstance(self.model, DeviceSettable): self.model.device = value + elif hasattr(self.model, "to"): + self.model.to(value) self._device = value @@ -175,3 +179,9 @@ class ModelOptions(TypedDict, total=False): disable_cfg1_optimization: NotRequired[bool] denoise_mask_function: NotRequired[Callable] patches: NotRequired[dict[str, list]] + +class LoadingListItem(NamedTuple): + module_size: int + name: str + module: torch.nn.Module + params: list[str] \ No newline at end of file diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 68e07cc03..81381fcd0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -39,7 +39,7 @@ from .float import stochastic_rounding from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue from .model_base import BaseModel -from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT +from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection logger = logging.getLogger(__name__) @@ -627,7 +627,8 @@ class ModelPatcher(ModelManageable): else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) - def _load_list(self): + + def _load_list(self) -> list[LoadingListItem]: loading = [] for n, m in self.model.named_modules(): params = [] @@ -639,7 +640,7 @@ class ModelPatcher(ModelManageable): skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((model_management.module_size(m), n, m, params)) + loading.append(LoadingListItem(model_management.module_size(m), n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -650,13 +651,13 @@ class ModelPatcher(ModelManageable): lowvram_counter = 0 loading = self._load_list() - load_completely = [] + load_completely: list[LoadingListItem] = [] loading.sort(reverse=True) for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] + n = x.name + m = x.module + params = x.params + module_mem = x.module_size lowvram_weight = False @@ -696,7 +697,7 @@ class ModelPatcher(ModelManageable): if full_load or mem_counter + module_mem < lowvram_model_memory: mem_counter += module_mem - load_completely.append((module_mem, n, m, params)) + load_completely.append(LoadingListItem(module_mem, n, m, params)) if cast_weight: m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -712,9 +713,9 @@ class ModelPatcher(ModelManageable): load_completely.sort(reverse=True) for x in load_completely: - n = x[1] - m = x[2] - params = x[3] + n = x.name + m = x.module + params = x.params if hasattr(m, "comfy_patched_weights"): if m.comfy_patched_weights == True: continue @@ -726,7 +727,7 @@ class ModelPatcher(ModelManageable): m.comfy_patched_weights = True for x in load_completely: - x[2].to(device_to) + x.module.to(device_to) if lowvram_counter > 0: logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) @@ -791,7 +792,9 @@ class ModelPatcher(ModelManageable): self.backup.clear() if device_to is not None: - self.model.to(device_to) + if hasattr(self.model, "to"): + # todo: is this now redundant with self.model.to? + self.model.to(device_to) self.model_device = device_to self._memory_measurements.model_loaded_weight_memory = 0 diff --git a/comfy_extras/nodes/nodes_cosmos.py b/comfy_extras/nodes/nodes_cosmos.py index 1a1895c51..5e20bf487 100644 --- a/comfy_extras/nodes/nodes_cosmos.py +++ b/comfy_extras/nodes/nodes_cosmos.py @@ -2,11 +2,9 @@ import torch import comfy.model_management import comfy.utils -from comfy.language.language_types import LanguageModel from comfy.node_helpers import export_custom_nodes from comfy.nodes.common import MAX_RESOLUTION -from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult -from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize, _AUTO_CHAT_TEMPLATE +from comfy.nodes.package_typing import CustomNode class EmptyCosmosLatentVideo(CustomNode): @@ -81,32 +79,4 @@ class CosmosImageToVideoLatent(CustomNode): return (out_latent,) -class CosmosPromptUpsamplerTransformersLoader(TransformersLoader): - @classmethod - def INPUT_TYPES(cls) -> InputTypes: - return { - "required": { - "ckpt_name": ("STRING", {}), - }, - } - - -# from https://github.com/NVIDIA/Cosmos/blob/b867572b99d08f450ddb8bcd6661d8c35bf6b967/cosmos1/models/diffusion/nemo/inference/inference_utils.py#L54 -FROM_COSMOS_REPO_PROMPT_PREFIX = "Upsample the short caption to a long caption: " - - -class CosmosUpsamplePromptTokenize(OneShotInstructTokenize): - @classmethod - def INPUT_TYPES(cls) -> InputTypes: - return { - "required": { - "model": ("MODEL",), - "prompt": ("STRING", {"default": "", "multiline": True}), - }, - } - - def execute(self, model: LanguageModel, prompt: str, images: list[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult: - return super().execute(model, f"{FROM_COSMOS_REPO_PROMPT_PREFIX}{prompt}", images=None, chat_template=_AUTO_CHAT_TEMPLATE) - - export_custom_nodes() diff --git a/comfy_extras/nodes/nodes_cosmos_upsampling.py b/comfy_extras/nodes/nodes_cosmos_upsampling.py new file mode 100644 index 000000000..51d165fdf --- /dev/null +++ b/comfy_extras/nodes/nodes_cosmos_upsampling.py @@ -0,0 +1,207 @@ +import re +from pathlib import Path +from typing import Optional + +import torch + +from comfy import model_management +from comfy.component_model.empty_init import EmptyInitOnDevice +from comfy.component_model.tensor_types import RGBImageBatch +from comfy.language.language_types import LanguageModel, ProcessorResult, LanguagePrompt, GENERATION_KWARGS_TYPE, \ + TOKENS_TYPE +from comfy.model_downloader import get_or_download_huggingface_repo +from comfy.model_management import load_models_gpu +from comfy.model_patcher import ModelPatcher +from comfy.node_helpers import export_custom_nodes +from comfy.nodes.package_typing import InputTypes, ValidatedNodeResult, CustomNode +from comfy_extras.nodes.nodes_language import TransformersLoader, TransformersTokenize, OneShotInstructTokenize, \ + _AUTO_CHAT_TEMPLATE + +# from https://github.com/NVIDIA/Cosmos/blob/b867572b99d08f450ddb8bcd6661d8c35bf6b967/cosmos1/models/diffusion/nemo/inference/inference_utils.py#L54 +COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK = "Upsample the short caption to a long caption: " +COSMOS_VIDEO_TO_WORLD_UPSAMPLE_TASK = """ +Your task is to transform a given prompt into a refined and concise video description, no more than 150 words. +Focus only on the content, no filler words or descriptions on the style. Never mention things outside the video. +""" + + +def clean_text(text: str) -> str: + """Clean the text by removing prefixes, suffixes, formatting markers, and normalizing whitespace.""" + # Replace all variations of newlines with a space + text = text.replace("\n", " ").replace("\r", " ") + + # Use a regex to find sections of the form '- **...**' + pattern = r"(- \*\*)(.*?)(\*\*)" + + def replacement(match: re.Match[str]) -> str: + content = match.group(2) # The text inside - ** and ** + words = re.findall(r"\w+", content) + if len(words) < 10: + # If fewer than 10 words, remove the entire '- **...**' portion + return "" + else: + # If 10 or more words, keep the entire section as it is + return match.group(0) + + text = re.sub(pattern, replacement, text) + + # Remove common prefixes + prefixes = ["Caption:", "#####", "####", "- ", "* ", ","] + for prefix in prefixes: + # lstrip(prefix) won't strip entire strings, but character sets. + # For more reliable prefix removal, do: + if text.startswith(prefix): + text = text[len(prefix):].lstrip() + + # Remove extra spaces + text = " ".join(text.split()) + + # Strip any remaining leading/trailing punctuation, whitespace, and quotes + text = text.strip(' -,*:"\'"“”') + + return text + + +class CosmosPromptUpsamplerTransformersLoader(TransformersLoader): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "ckpt_name": (["unsloth/Pixtral-12B-2409"], {}), + }, + } + + +class Mistral12b(LanguageModel): + def __init__(self, model: ModelPatcher, ckpt_name: str): + self.model = model + self.ckpt_name = ckpt_name + + @staticmethod + def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "Mistral12b": + from cosmos1.models.autoregressive.configs.base.model_config import create_text_model_config # pylint: disable=import-error + from cosmos1.models.autoregressive.model import AutoRegressiveModel # pylint: disable=import-error + checkpoint_dir = get_or_download_huggingface_repo(ckpt_name) + assert checkpoint_dir is not None, f"did not successfully download {ckpt_name}" + checkpoint_dir = Path(checkpoint_dir) + model_config, tokenizer_config = create_text_model_config( + model_ckpt_path=str(checkpoint_dir / "model.pt"), + tokenizer_path=str(checkpoint_dir), + model_family="mistral", + model_size="12b", + is_instruct_model=True, + max_batch_size=1, + rope_dim="1D", + add_special_tokens=True, + max_seq_len=1024, + pytorch_rope_version="v1", + ) + + try: + with EmptyInitOnDevice(device=model_management.unet_offload_device()): + completion_instance_cpu = AutoRegressiveModel.build(model_config=model_config, tokenizer_config=tokenizer_config) + finally: + torch.set_default_dtype(torch.float32) + + patchable_completion_instance_cpu = ModelPatcher(completion_instance_cpu, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), size=completion_instance_cpu.get_num_params() * 2, ckpt_name=ckpt_name) + return Mistral12b(patchable_completion_instance_cpu, ckpt_name=ckpt_name) + + def generate(self, + tokens: TOKENS_TYPE = None, + max_new_tokens: int = 512, + repetition_penalty: float = 0.0, + seed: int = 0, + sampler: Optional[GENERATION_KWARGS_TYPE] = None, + *args, + **kwargs) -> str: + sampler = sampler or {} + prompt = tokens.get("inputs", []) + prompt = "".join(prompt) + + dialogs = [[{"role": "user", "content": prompt}]] + + from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion # pylint: disable=import-error + from cosmos1.models.autoregressive.model import AutoRegressiveModel # pylint: disable=import-error + + load_models_gpu([self.model]) + + # noinspection PyTypeChecker + model: AutoRegressiveModel = self.model.model + assert isinstance(model, AutoRegressiveModel) + + results = chat_completion( + model, + dialogs, + seed=seed, + max_gen_len=max_new_tokens, + temperature=sampler.get("temperature", 0.01), + top_p=sampler.get("top_p", None), + top_k=sampler.get("top_k", None), + logprobs=False, + ) + + upsampled_prompt = str(clean_text(results[0]["generation"]["content"])) + return upsampled_prompt + + def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult: + # Return prompts and image as is + return { + "inputs": [prompt], + "attention_mask": torch.ones(1, len(prompt)), # Dummy attention mask + "images": images + } + + @property + def repo_id(self) -> str: + return self.ckpt_name + + +class CosmosText2WorldPromptUpsamplerLoader(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "ckpt_name": (["nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World"], {}), + }, + } + + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("upsampler model",) + CATEGORY = "cosmos" + FUNCTION = "execute" + + def execute(self, ckpt_name: str) -> tuple[LanguageModel]: + return Mistral12b.from_pretrained(ckpt_name), + + +class CosmosText2WorldUpsamplePromptTokenize(TransformersTokenize): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "model": ("MODEL",), + "prompt": ("STRING", {"default": "", "multiline": True}), + }, + } + + def execute(self, model: LanguageModel, prompt: str) -> ValidatedNodeResult: + return super().execute(model, f"{COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK}{prompt}") + + +class CosmosVideo2WorldUpsamplePromptTokenize(OneShotInstructTokenize): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "model": ("MODEL", {}), + }, + "optional": { + "images": ("IMAGE", {}), + } + } + + def execute(self, model: LanguageModel, prompt: str = None, images: list[torch.Tensor] | torch.Tensor = None, chat_template: str = _AUTO_CHAT_TEMPLATE) -> ValidatedNodeResult: + return super().execute(model, COSMOS_VIDEO_TO_WORLD_UPSAMPLE_TASK, images, _AUTO_CHAT_TEMPLATE) + + +export_custom_nodes() diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py index fcb37459d..dfa6a1d19 100644 --- a/comfy_extras/nodes/nodes_torch_compile.py +++ b/comfy_extras/nodes/nodes_torch_compile.py @@ -8,6 +8,8 @@ import torch._inductor.codecache from torch.nn import LayerNorm from comfy import model_management +from comfy.language.language_types import LanguageModel +from comfy.language.transformers_model_management import TransformersManagedModel from comfy.model_patcher import ModelPatcher from comfy.nodes.package_typing import CustomNode, InputTypes @@ -97,7 +99,7 @@ class TorchCompileModel(CustomNode): } move_to_gpu = True del compile_kwargs["mode"] - if isinstance(model, ModelPatcher): + if isinstance(model, ModelPatcher) or isinstance(model, TransformersManagedModel): m = model.clone() if move_to_gpu: model_management.load_models_gpu([m]) @@ -115,12 +117,19 @@ class TorchCompileModel(CustomNode): else: logging.warning("Encountered a model that cannot be compiled") return model, - except OSError: + except OSError as os_error: try: torch._inductor.utils.clear_inductor_caches() # pylint: disable=no-member except Exception: pass - raise + raise os_error + except Exception as exc_info: + try: + torch._inductor.utils.clear_inductor_caches() # pylint: disable=no-member + except Exception: + pass + logging.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info) + return model, _QUANTIZATION_STRATEGIES = [ diff --git a/docs/assets/prompt_upsampling_01.webp b/docs/assets/prompt_upsampling_01.webp new file mode 100644 index 000000000..c54f64fd1 Binary files /dev/null and b/docs/assets/prompt_upsampling_01.webp differ diff --git a/docs/assets/prompt_upsampling_02.webp b/docs/assets/prompt_upsampling_02.webp new file mode 100644 index 000000000..8bf8933df Binary files /dev/null and b/docs/assets/prompt_upsampling_02.webp differ