Prompt upsampling, better torch.compile support for language models

This commit is contained in:
doctorpangloss 2025-03-03 18:36:47 -08:00
parent c6111fae7d
commit d82261485f
11 changed files with 316 additions and 51 deletions

View File

@ -479,6 +479,28 @@ comfyui --use-sage-attention
![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp)
**With PyTorch Attention**
## Cosmos Prompt Upsampling
The Cosmos prompt "upsampler," a fine tune of Mistral-Nemo-12b, correctly rewrites Cosmos prompts in the narrative style that NVIDIA's captioner used for the training data of Cosmos, improving generation results significantly.
Here is a comparison between a simple and "upsampled" prompt.
![prompt_upsampling_01.webp](docs/assets/prompt_upsampling_01.webp)
**A dog is playing with a ball.**
![prompt_upsampling_02.webp](docs/assets/prompt_upsampling_02.webp)
**In a sun-drenched park, a playful golden retriever bounds joyfully across the lush green grass, its tail wagging with excitement. The dog, adorned with a vibrant red collar, is captivated by a bright yellow ball, which it cradles gently in its mouth. The camera captures the dog's animated expressions, from eager anticipation to sheer delight, as it trots and leaps, showcasing its agility and enthusiasm. The scene is bathed in warm, golden-hour light, enhancing the vibrant colors of the dog's fur and the ball. The background features a serene tree line, framing the playful interaction and creating a tranquil atmosphere. The static camera angle allows for an intimate focus on the dog's joyful antics, inviting viewers to share in this heartwarming moment of pure canine happiness.**
To use the Cosmos upsampler, install the prerequisites:
```shell
uv pip install loguru pynvml
uv pip install --no-deps git+https://github.com/NVIDIA/Cosmos.git
```
Then, use the workflow embedded in the upsampled prompt by dragging and dropping the upsampled animation into your workspace.
The Cosmos upsampler ought to improve any text-to-image video generation pipeline. Use the `Video2World` upsampler nodes to download Pixtral-12b and upsample for an image to video workflow using NVIDIA's default prompt. Since Pixtral is not fine tuned, the improvement may not be significant over using another LLM.
# Custom Nodes
Custom Nodes can be added to ComfyUI by copying and pasting Python files into your `./custom_nodes` directory.

View File

@ -0,0 +1,19 @@
import torch
import torch.overrides
import torch.utils._device
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
def __init__(self, device=None):
self.device = device
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if getattr(func, '__module__', None) == 'torch.nn.init':
if 'tensor' in kwargs:
return kwargs['tensor']
else:
return args[0]
if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
kwargs['device'] = self.device
return func(*args, **kwargs)

View File

@ -6,6 +6,7 @@ import inspect
import logging
import operator
import pathlib
import weakref
from functools import reduce
from typing import Optional, Any, Callable
@ -43,9 +44,10 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
processor: Optional[ProcessorMixin | AutoProcessor] = None
):
self._repo_id = repo_id
self.model = model
self._model = model
self._tokenizer = tokenizer
self._processor = processor
self._object_patches: dict[str, Any] = {}
self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
self.load_device = get_torch_device()
@ -53,6 +55,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
self._config_dict = config_dict
self._on_set_processor(self._processor)
self._model_type = ""
self._original_transformers_managed_model: weakref.ReferenceType["TransformersManagedModel"] = weakref.ref(self)
if model.device != self.offload_device:
model.to(device=self.offload_device)
@ -426,6 +429,28 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
else:
return f"<TransformersManagedModel for {self.model.__class__.__name__}>"
def clone(self) -> TransformersManagedModel:
m = copy.copy(self)
# deep copy a few objects
m._object_patches = copy.copy(self._object_patches)
return m
def add_object_patch(self, name: str, obj: Any):
# for the sake of compatibility, rewrite the name to the actual model field
if name == "diffusion_model":
name = "model"
self._object_patches[name] = obj
def get_model_object(self, name: str) -> torch.nn.Module:
if name == "diffusion_model":
name = "model"
return super().get_model_object(name)
@property
def model(self) -> PreTrainedModel | torch.nn.Module:
return self._object_patches.get("model", self._model)
class _ProgressTextStreamer(TextStreamer):
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):

View File

@ -406,7 +406,7 @@ class LoadedModel:
self._set_model(model)
@property
def model(self):
def model(self) -> ModelManageable:
return self._model()
def model_memory(self):

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import dataclasses
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple
import torch
import torch.nn
@ -150,6 +150,8 @@ class MemoryMeasurements:
def device(self) -> torch.device:
if isinstance(self.model, DeviceSettable):
return self.model.device
elif hasattr(self.model, "device"):
return self.model.device
else:
return self._device
@ -157,6 +159,8 @@ class MemoryMeasurements:
def device(self, value: torch.device):
if isinstance(self.model, DeviceSettable):
self.model.device = value
elif hasattr(self.model, "to"):
self.model.to(value)
self._device = value
@ -175,3 +179,9 @@ class ModelOptions(TypedDict, total=False):
disable_cfg1_optimization: NotRequired[bool]
denoise_mask_function: NotRequired[Callable]
patches: NotRequired[dict[str, list]]
class LoadingListItem(NamedTuple):
module_size: int
name: str
module: torch.nn.Module
params: list[str]

View File

@ -39,7 +39,7 @@ from .float import stochastic_rounding
from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
from .model_base import BaseModel
from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT
from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem
from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
logger = logging.getLogger(__name__)
@ -627,7 +627,8 @@ class ModelPatcher(ModelManageable):
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _load_list(self):
def _load_list(self) -> list[LoadingListItem]:
loading = []
for n, m in self.model.named_modules():
params = []
@ -639,7 +640,7 @@ class ModelPatcher(ModelManageable):
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((model_management.module_size(m), n, m, params))
loading.append(LoadingListItem(model_management.module_size(m), n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@ -650,13 +651,13 @@ class ModelPatcher(ModelManageable):
lowvram_counter = 0
loading = self._load_list()
load_completely = []
load_completely: list[LoadingListItem] = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
params = x[3]
module_mem = x[0]
n = x.name
m = x.module
params = x.params
module_mem = x.module_size
lowvram_weight = False
@ -696,7 +697,7 @@ class ModelPatcher(ModelManageable):
if full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
load_completely.append(LoadingListItem(module_mem, n, m, params))
if cast_weight:
m.prev_comfy_cast_weights = m.comfy_cast_weights
@ -712,9 +713,9 @@ class ModelPatcher(ModelManageable):
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
m = x[2]
params = x[3]
n = x.name
m = x.module
params = x.params
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
@ -726,7 +727,7 @@ class ModelPatcher(ModelManageable):
m.comfy_patched_weights = True
for x in load_completely:
x[2].to(device_to)
x.module.to(device_to)
if lowvram_counter > 0:
logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
@ -791,6 +792,8 @@ class ModelPatcher(ModelManageable):
self.backup.clear()
if device_to is not None:
if hasattr(self.model, "to"):
# todo: is this now redundant with self.model.to?
self.model.to(device_to)
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = 0

View File

@ -2,11 +2,9 @@ import torch
import comfy.model_management
import comfy.utils
from comfy.language.language_types import LanguageModel
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.common import MAX_RESOLUTION
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize, _AUTO_CHAT_TEMPLATE
from comfy.nodes.package_typing import CustomNode
class EmptyCosmosLatentVideo(CustomNode):
@ -81,32 +79,4 @@ class CosmosImageToVideoLatent(CustomNode):
return (out_latent,)
class CosmosPromptUpsamplerTransformersLoader(TransformersLoader):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": ("STRING", {}),
},
}
# from https://github.com/NVIDIA/Cosmos/blob/b867572b99d08f450ddb8bcd6661d8c35bf6b967/cosmos1/models/diffusion/nemo/inference/inference_utils.py#L54
FROM_COSMOS_REPO_PROMPT_PREFIX = "Upsample the short caption to a long caption: "
class CosmosUpsamplePromptTokenize(OneShotInstructTokenize):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
},
}
def execute(self, model: LanguageModel, prompt: str, images: list[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
return super().execute(model, f"{FROM_COSMOS_REPO_PROMPT_PREFIX}{prompt}", images=None, chat_template=_AUTO_CHAT_TEMPLATE)
export_custom_nodes()

View File

@ -0,0 +1,207 @@
import re
from pathlib import Path
from typing import Optional
import torch
from comfy import model_management
from comfy.component_model.empty_init import EmptyInitOnDevice
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.language.language_types import LanguageModel, ProcessorResult, LanguagePrompt, GENERATION_KWARGS_TYPE, \
TOKENS_TYPE
from comfy.model_downloader import get_or_download_huggingface_repo
from comfy.model_management import load_models_gpu
from comfy.model_patcher import ModelPatcher
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import InputTypes, ValidatedNodeResult, CustomNode
from comfy_extras.nodes.nodes_language import TransformersLoader, TransformersTokenize, OneShotInstructTokenize, \
_AUTO_CHAT_TEMPLATE
# from https://github.com/NVIDIA/Cosmos/blob/b867572b99d08f450ddb8bcd6661d8c35bf6b967/cosmos1/models/diffusion/nemo/inference/inference_utils.py#L54
COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK = "Upsample the short caption to a long caption: "
COSMOS_VIDEO_TO_WORLD_UPSAMPLE_TASK = """
Your task is to transform a given prompt into a refined and concise video description, no more than 150 words.
Focus only on the content, no filler words or descriptions on the style. Never mention things outside the video.
"""
def clean_text(text: str) -> str:
"""Clean the text by removing prefixes, suffixes, formatting markers, and normalizing whitespace."""
# Replace all variations of newlines with a space
text = text.replace("\n", " ").replace("\r", " ")
# Use a regex to find sections of the form '- **...**'
pattern = r"(- \*\*)(.*?)(\*\*)"
def replacement(match: re.Match[str]) -> str:
content = match.group(2) # The text inside - ** and **
words = re.findall(r"\w+", content)
if len(words) < 10:
# If fewer than 10 words, remove the entire '- **...**' portion
return ""
else:
# If 10 or more words, keep the entire section as it is
return match.group(0)
text = re.sub(pattern, replacement, text)
# Remove common prefixes
prefixes = ["Caption:", "#####", "####", "- ", "* ", ","]
for prefix in prefixes:
# lstrip(prefix) won't strip entire strings, but character sets.
# For more reliable prefix removal, do:
if text.startswith(prefix):
text = text[len(prefix):].lstrip()
# Remove extra spaces
text = " ".join(text.split())
# Strip any remaining leading/trailing punctuation, whitespace, and quotes
text = text.strip(' -,*:"\'"“”')
return text
class CosmosPromptUpsamplerTransformersLoader(TransformersLoader):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (["unsloth/Pixtral-12B-2409"], {}),
},
}
class Mistral12b(LanguageModel):
def __init__(self, model: ModelPatcher, ckpt_name: str):
self.model = model
self.ckpt_name = ckpt_name
@staticmethod
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "Mistral12b":
from cosmos1.models.autoregressive.configs.base.model_config import create_text_model_config # pylint: disable=import-error
from cosmos1.models.autoregressive.model import AutoRegressiveModel # pylint: disable=import-error
checkpoint_dir = get_or_download_huggingface_repo(ckpt_name)
assert checkpoint_dir is not None, f"did not successfully download {ckpt_name}"
checkpoint_dir = Path(checkpoint_dir)
model_config, tokenizer_config = create_text_model_config(
model_ckpt_path=str(checkpoint_dir / "model.pt"),
tokenizer_path=str(checkpoint_dir),
model_family="mistral",
model_size="12b",
is_instruct_model=True,
max_batch_size=1,
rope_dim="1D",
add_special_tokens=True,
max_seq_len=1024,
pytorch_rope_version="v1",
)
try:
with EmptyInitOnDevice(device=model_management.unet_offload_device()):
completion_instance_cpu = AutoRegressiveModel.build(model_config=model_config, tokenizer_config=tokenizer_config)
finally:
torch.set_default_dtype(torch.float32)
patchable_completion_instance_cpu = ModelPatcher(completion_instance_cpu, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), size=completion_instance_cpu.get_num_params() * 2, ckpt_name=ckpt_name)
return Mistral12b(patchable_completion_instance_cpu, ckpt_name=ckpt_name)
def generate(self,
tokens: TOKENS_TYPE = None,
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs) -> str:
sampler = sampler or {}
prompt = tokens.get("inputs", [])
prompt = "".join(prompt)
dialogs = [[{"role": "user", "content": prompt}]]
from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion # pylint: disable=import-error
from cosmos1.models.autoregressive.model import AutoRegressiveModel # pylint: disable=import-error
load_models_gpu([self.model])
# noinspection PyTypeChecker
model: AutoRegressiveModel = self.model.model
assert isinstance(model, AutoRegressiveModel)
results = chat_completion(
model,
dialogs,
seed=seed,
max_gen_len=max_new_tokens,
temperature=sampler.get("temperature", 0.01),
top_p=sampler.get("top_p", None),
top_k=sampler.get("top_k", None),
logprobs=False,
)
upsampled_prompt = str(clean_text(results[0]["generation"]["content"]))
return upsampled_prompt
def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
# Return prompts and image as is
return {
"inputs": [prompt],
"attention_mask": torch.ones(1, len(prompt)), # Dummy attention mask
"images": images
}
@property
def repo_id(self) -> str:
return self.ckpt_name
class CosmosText2WorldPromptUpsamplerLoader(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (["nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World"], {}),
},
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("upsampler model",)
CATEGORY = "cosmos"
FUNCTION = "execute"
def execute(self, ckpt_name: str) -> tuple[LanguageModel]:
return Mistral12b.from_pretrained(ckpt_name),
class CosmosText2WorldUpsamplePromptTokenize(TransformersTokenize):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
},
}
def execute(self, model: LanguageModel, prompt: str) -> ValidatedNodeResult:
return super().execute(model, f"{COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK}{prompt}")
class CosmosVideo2WorldUpsamplePromptTokenize(OneShotInstructTokenize):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL", {}),
},
"optional": {
"images": ("IMAGE", {}),
}
}
def execute(self, model: LanguageModel, prompt: str = None, images: list[torch.Tensor] | torch.Tensor = None, chat_template: str = _AUTO_CHAT_TEMPLATE) -> ValidatedNodeResult:
return super().execute(model, COSMOS_VIDEO_TO_WORLD_UPSAMPLE_TASK, images, _AUTO_CHAT_TEMPLATE)
export_custom_nodes()

View File

@ -8,6 +8,8 @@ import torch._inductor.codecache
from torch.nn import LayerNorm
from comfy import model_management
from comfy.language.language_types import LanguageModel
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
@ -97,7 +99,7 @@ class TorchCompileModel(CustomNode):
}
move_to_gpu = True
del compile_kwargs["mode"]
if isinstance(model, ModelPatcher):
if isinstance(model, ModelPatcher) or isinstance(model, TransformersManagedModel):
m = model.clone()
if move_to_gpu:
model_management.load_models_gpu([m])
@ -115,12 +117,19 @@ class TorchCompileModel(CustomNode):
else:
logging.warning("Encountered a model that cannot be compiled")
return model,
except OSError:
except OSError as os_error:
try:
torch._inductor.utils.clear_inductor_caches() # pylint: disable=no-member
except Exception:
pass
raise
raise os_error
except Exception as exc_info:
try:
torch._inductor.utils.clear_inductor_caches() # pylint: disable=no-member
except Exception:
pass
logging.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
return model,
_QUANTIZATION_STRATEGIES = [

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.8 MiB