Prompt upsampling, better torch.compile support for language models

doctorpangloss 2025-03-03 18:36:47 -08:00
parent c6111fae7d
commit d82261485f
11 changed files with 316 additions and 51 deletions

View File

@@ -479,6 +479,28 @@ comfyui --use-sage-attention
![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp)
**With PyTorch Attention**
## Cosmos Prompt Upsampling
The Cosmos prompt "upsampler," a fine-tune of Mistral-Nemo-12b, rewrites Cosmos prompts in the narrative style that NVIDIA's captioner used for Cosmos's training data, significantly improving generation results.
Here is a comparison between a simple prompt and its "upsampled" counterpart.
![prompt_upsampling_01.webp](docs/assets/prompt_upsampling_01.webp)
**A dog is playing with a ball.**
![prompt_upsampling_02.webp](docs/assets/prompt_upsampling_02.webp)
**In a sun-drenched park, a playful golden retriever bounds joyfully across the lush green grass, its tail wagging with excitement. The dog, adorned with a vibrant red collar, is captivated by a bright yellow ball, which it cradles gently in its mouth. The camera captures the dog's animated expressions, from eager anticipation to sheer delight, as it trots and leaps, showcasing its agility and enthusiasm. The scene is bathed in warm, golden-hour light, enhancing the vibrant colors of the dog's fur and the ball. The background features a serene tree line, framing the playful interaction and creating a tranquil atmosphere. The static camera angle allows for an intimate focus on the dog's joyful antics, inviting viewers to share in this heartwarming moment of pure canine happiness.**
To use the Cosmos upsampler, install the prerequisites:
```shell
uv pip install loguru pynvml
uv pip install --no-deps git+https://github.com/NVIDIA/Cosmos.git
```
Then, drag and drop the upsampled animation above into your workspace to load the workflow embedded in it.
The Cosmos upsampler ought to improve any text-to-image or text-to-video generation pipeline. Use the `Video2World` upsampler nodes to download Pixtral-12b and upsample prompts for an image-to-video workflow using NVIDIA's default prompt. Since Pixtral is not fine-tuned for this task, the improvement may not be significant over using another LLM.
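For headless use, the upsampler nodes added in this commit can also be driven directly from Python. A minimal sketch, assuming the classes are importable from the module this commit adds (the import path below is illustrative) and that the node `execute` methods return one-element tuples:

```python
# Hypothetical headless driver for the Text2World upsampler nodes added in this commit.
# The module path is an assumption; class and method names come from the diff below.
from comfy_extras.nodes.nodes_cosmos_upsampling import (
    CosmosText2WorldPromptUpsamplerLoader,
    CosmosText2WorldUpsamplePromptTokenize,
)

model, = CosmosText2WorldPromptUpsamplerLoader().execute(
    ckpt_name="nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World")
tokens, = CosmosText2WorldUpsamplePromptTokenize().execute(model, "A dog is playing with a ball.")
print(model.generate(tokens=tokens, max_new_tokens=512))
```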
# Custom Nodes
Custom Nodes can be added to ComfyUI by copying and pasting Python files into your `./custom_nodes` directory.

View File

@@ -0,0 +1,19 @@
import torch
import torch.overrides
import torch.utils._device
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
def __init__(self, device=None):
self.device = device
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if getattr(func, '__module__', None) == 'torch.nn.init':
if 'tensor' in kwargs:
return kwargs['tensor']
else:
return args[0]
if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
kwargs['device'] = self.device
return func(*args, **kwargs)
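A brief usage sketch (illustrative, not part of this commit): inside the context manager, tensor constructors are redirected to the chosen device and `torch.nn.init` calls return their input unchanged, so large models can be materialized without paying for random weight initialization.

```python
import torch
from comfy.component_model.empty_init import EmptyInitOnDevice  # import path as used later in this commit

with EmptyInitOnDevice(device=torch.device("meta")):
    # constructors pick up the target device; nn.init.* becomes a no-op
    layer = torch.nn.Linear(4096, 4096)

assert layer.weight.device.type == "meta"
```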

View File

@@ -6,6 +6,7 @@ import inspect
import logging
import operator
import pathlib
import weakref
from functools import reduce
from typing import Optional, Any, Callable
@@ -43,9 +44,10 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
processor: Optional[ProcessorMixin | AutoProcessor] = None
):
self._repo_id = repo_id
self.model = model
self._model = model
self._tokenizer = tokenizer
self._processor = processor
self._object_patches: dict[str, Any] = {}
self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
self.load_device = get_torch_device()
@@ -53,6 +55,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
self._config_dict = config_dict
self._on_set_processor(self._processor)
self._model_type = ""
self._original_transformers_managed_model: weakref.ReferenceType["TransformersManagedModel"] = weakref.ref(self)
if model.device != self.offload_device:
model.to(device=self.offload_device)
@@ -426,6 +429,28 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
else:
return f"<TransformersManagedModel for {self.model.__class__.__name__}>"
def clone(self) -> TransformersManagedModel:
m = copy.copy(self)
# deep copy a few objects
m._object_patches = copy.copy(self._object_patches)
return m
def add_object_patch(self, name: str, obj: Any):
# for the sake of compatibility, rewrite the name to the actual model field
if name == "diffusion_model":
name = "model"
self._object_patches[name] = obj
def get_model_object(self, name: str) -> torch.nn.Module:
if name == "diffusion_model":
name = "model"
return super().get_model_object(name)
@property
def model(self) -> PreTrainedModel | torch.nn.Module:
return self._object_patches.get("model", self._model)
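A short illustrative sketch (assumptions: `lm` is an already-constructed `TransformersManagedModel`) of how the object-patch mechanism lets callers such as `TorchCompileModel` swap in a compiled module without mutating the original:

```python
import torch

patched = lm.clone()                                   # shallow copy; the patches dict is copied
compiled = torch.compile(patched.get_model_object("diffusion_model"))
patched.add_object_patch("diffusion_model", compiled)  # name is rewritten to the "model" field
assert patched.model is compiled                       # the property prefers the patch
assert lm.model is not compiled                        # the original is left untouched
```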
class _ProgressTextStreamer(TextStreamer):
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):

View File

@@ -406,7 +406,7 @@ class LoadedModel:
self._set_model(model)
@property
def model(self):
def model(self) -> ModelManageable:
return self._model()
def model_memory(self):

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import dataclasses
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple
import torch
import torch.nn
@@ -150,6 +150,8 @@ class MemoryMeasurements:
def device(self) -> torch.device:
if isinstance(self.model, DeviceSettable):
return self.model.device
elif hasattr(self.model, "device"):
return self.model.device
else:
return self._device
@@ -157,6 +159,8 @@
def device(self, value: torch.device):
if isinstance(self.model, DeviceSettable):
self.model.device = value
elif hasattr(self.model, "to"):
self.model.to(value)
self._device = value
@@ -175,3 +179,9 @@ class ModelOptions(TypedDict, total=False):
disable_cfg1_optimization: NotRequired[bool]
denoise_mask_function: NotRequired[Callable]
patches: NotRequired[dict[str, list]]
class LoadingListItem(NamedTuple):
module_size: int
name: str
module: torch.nn.Module
params: list[str]
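Because `_load_list` still sorts its result with `sort(reverse=True)`, the field order of `LoadingListItem` matters: named tuples compare like plain tuples, so the list is still ordered by `module_size` first, exactly as the old anonymous tuples were. A small illustrative check (the import path is inferred from the relative import used elsewhere in this commit):

```python
import torch
from comfy.model_management_types import LoadingListItem  # inferred path

small = LoadingListItem(1024, "blocks.0", torch.nn.Identity(), ["weight"])
large = LoadingListItem(4096, "blocks.1", torch.nn.Identity(), ["weight"])
# reverse sort compares module_size first, so the largest module comes first
assert sorted([small, large], reverse=True)[0] is large
```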

View File

@@ -39,7 +39,7 @@ from .float import stochastic_rounding
from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
from .model_base import BaseModel
from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT
from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem
from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
logger = logging.getLogger(__name__)
@@ -627,7 +627,8 @@ class ModelPatcher(ModelManageable):
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _load_list(self):
def _load_list(self) -> list[LoadingListItem]:
loading = []
for n, m in self.model.named_modules():
params = []
@@ -639,7 +640,7 @@
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((model_management.module_size(m), n, m, params))
loading.append(LoadingListItem(model_management.module_size(m), n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -650,13 +651,13 @@
lowvram_counter = 0
loading = self._load_list()
load_completely = []
load_completely: list[LoadingListItem] = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
n = x.name
m = x[2]
m = x.module
params = x[3]
params = x.params
module_mem = x[0]
module_mem = x.module_size
lowvram_weight = False
@@ -696,7 +697,7 @@
if full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
load_completely.append(LoadingListItem(module_mem, n, m, params))
if cast_weight:
m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -712,9 +713,9 @@
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
n = x.name
m = x[2]
m = x.module
params = x[3]
params = x.params
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
@@ -726,7 +727,7 @@
m.comfy_patched_weights = True
for x in load_completely:
x[2].to(device_to)
x.module.to(device_to)
if lowvram_counter > 0:
logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
@@ -791,6 +792,8 @@
self.backup.clear()
if device_to is not None:
if hasattr(self.model, "to"):
# todo: is this now redundant with self.model.to?
self.model.to(device_to)
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = 0

View File

@@ -2,11 +2,9 @@ import torch
import comfy.model_management
import comfy.utils
from comfy.language.language_types import LanguageModel
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.common import MAX_RESOLUTION
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.nodes.package_typing import CustomNode
from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize, _AUTO_CHAT_TEMPLATE
class EmptyCosmosLatentVideo(CustomNode):
@@ -81,32 +79,4 @@ class CosmosImageToVideoLatent(CustomNode):
return (out_latent,)
class CosmosPromptUpsamplerTransformersLoader(TransformersLoader):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": ("STRING", {}),
},
}
# from https://github.com/NVIDIA/Cosmos/blob/b867572b99d08f450ddb8bcd6661d8c35bf6b967/cosmos1/models/diffusion/nemo/inference/inference_utils.py#L54
FROM_COSMOS_REPO_PROMPT_PREFIX = "Upsample the short caption to a long caption: "
class CosmosUpsamplePromptTokenize(OneShotInstructTokenize):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
},
}
def execute(self, model: LanguageModel, prompt: str, images: list[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
return super().execute(model, f"{FROM_COSMOS_REPO_PROMPT_PREFIX}{prompt}", images=None, chat_template=_AUTO_CHAT_TEMPLATE)
export_custom_nodes()

View File

@@ -0,0 +1,207 @@
import re
from pathlib import Path
from typing import Optional
import torch
from comfy import model_management
from comfy.component_model.empty_init import EmptyInitOnDevice
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.language.language_types import LanguageModel, ProcessorResult, LanguagePrompt, GENERATION_KWARGS_TYPE, \
TOKENS_TYPE
from comfy.model_downloader import get_or_download_huggingface_repo
from comfy.model_management import load_models_gpu
from comfy.model_patcher import ModelPatcher
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import InputTypes, ValidatedNodeResult, CustomNode
from comfy_extras.nodes.nodes_language import TransformersLoader, TransformersTokenize, OneShotInstructTokenize, \
_AUTO_CHAT_TEMPLATE
# from https://github.com/NVIDIA/Cosmos/blob/b867572b99d08f450ddb8bcd6661d8c35bf6b967/cosmos1/models/diffusion/nemo/inference/inference_utils.py#L54
COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK = "Upsample the short caption to a long caption: "
COSMOS_VIDEO_TO_WORLD_UPSAMPLE_TASK = """
Your task is to transform a given prompt into a refined and concise video description, no more than 150 words.
Focus only on the content, no filler words or descriptions on the style. Never mention things outside the video.
"""
def clean_text(text: str) -> str:
"""Clean the text by removing prefixes, suffixes, formatting markers, and normalizing whitespace."""
# Replace all variations of newlines with a space
text = text.replace("\n", " ").replace("\r", " ")
# Use a regex to find sections of the form '- **...**'
pattern = r"(- \*\*)(.*?)(\*\*)"
def replacement(match: re.Match[str]) -> str:
content = match.group(2) # The text inside - ** and **
words = re.findall(r"\w+", content)
if len(words) < 10:
# If fewer than 10 words, remove the entire '- **...**' portion
return ""
else:
# If 10 or more words, keep the entire section as it is
return match.group(0)
text = re.sub(pattern, replacement, text)
# Remove common prefixes
prefixes = ["Caption:", "#####", "####", "- ", "* ", ","]
for prefix in prefixes:
# lstrip(prefix) won't strip entire strings, but character sets.
# For more reliable prefix removal, do:
if text.startswith(prefix):
text = text[len(prefix):].lstrip()
# Remove extra spaces
text = " ".join(text.split())
# Strip any remaining leading/trailing punctuation, whitespace, and quotes
text = text.strip(' -,*:"\'"“”')
return text
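A hand-traced example of `clean_text` (illustrative only): the short `- **...**` section is dropped, the `Caption:` prefix is stripped, and trailing characters from the strip set are removed.

```python
raw = "Caption: - **Short.**\nA playful golden retriever bounds across the grass, chasing a bright yellow ball. -"
assert clean_text(raw) == "A playful golden retriever bounds across the grass, chasing a bright yellow ball."
```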
class CosmosPromptUpsamplerTransformersLoader(TransformersLoader):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (["unsloth/Pixtral-12B-2409"], {}),
},
}
class Mistral12b(LanguageModel):
def __init__(self, model: ModelPatcher, ckpt_name: str):
self.model = model
self.ckpt_name = ckpt_name
@staticmethod
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "Mistral12b":
from cosmos1.models.autoregressive.configs.base.model_config import create_text_model_config # pylint: disable=import-error
from cosmos1.models.autoregressive.model import AutoRegressiveModel # pylint: disable=import-error
checkpoint_dir = get_or_download_huggingface_repo(ckpt_name)
assert checkpoint_dir is not None, f"did not successfully download {ckpt_name}"
checkpoint_dir = Path(checkpoint_dir)
model_config, tokenizer_config = create_text_model_config(
model_ckpt_path=str(checkpoint_dir / "model.pt"),
tokenizer_path=str(checkpoint_dir),
model_family="mistral",
model_size="12b",
is_instruct_model=True,
max_batch_size=1,
rope_dim="1D",
add_special_tokens=True,
max_seq_len=1024,
pytorch_rope_version="v1",
)
try:
with EmptyInitOnDevice(device=model_management.unet_offload_device()):
completion_instance_cpu = AutoRegressiveModel.build(model_config=model_config, tokenizer_config=tokenizer_config)
finally:
torch.set_default_dtype(torch.float32)
patchable_completion_instance_cpu = ModelPatcher(completion_instance_cpu, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), size=completion_instance_cpu.get_num_params() * 2, ckpt_name=ckpt_name)
return Mistral12b(patchable_completion_instance_cpu, ckpt_name=ckpt_name)
def generate(self,
tokens: TOKENS_TYPE = None,
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs) -> str:
sampler = sampler or {}
prompt = tokens.get("inputs", [])
prompt = "".join(prompt)
dialogs = [[{"role": "user", "content": prompt}]]
from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion # pylint: disable=import-error
from cosmos1.models.autoregressive.model import AutoRegressiveModel # pylint: disable=import-error
load_models_gpu([self.model])
# noinspection PyTypeChecker
model: AutoRegressiveModel = self.model.model
assert isinstance(model, AutoRegressiveModel)
results = chat_completion(
model,
dialogs,
seed=seed,
max_gen_len=max_new_tokens,
temperature=sampler.get("temperature", 0.01),
top_p=sampler.get("top_p", None),
top_k=sampler.get("top_k", None),
logprobs=False,
)
upsampled_prompt = str(clean_text(results[0]["generation"]["content"]))
return upsampled_prompt
def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
# Return prompts and image as is
return {
"inputs": [prompt],
"attention_mask": torch.ones(1, len(prompt)), # Dummy attention mask
"images": images
}
@property
def repo_id(self) -> str:
return self.ckpt_name
class CosmosText2WorldPromptUpsamplerLoader(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (["nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World"], {}),
},
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("upsampler model",)
CATEGORY = "cosmos"
FUNCTION = "execute"
def execute(self, ckpt_name: str) -> tuple[LanguageModel]:
return Mistral12b.from_pretrained(ckpt_name),
class CosmosText2WorldUpsamplePromptTokenize(TransformersTokenize):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
},
}
def execute(self, model: LanguageModel, prompt: str) -> ValidatedNodeResult:
return super().execute(model, f"{COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK}{prompt}")
class CosmosVideo2WorldUpsamplePromptTokenize(OneShotInstructTokenize):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL", {}),
},
"optional": {
"images": ("IMAGE", {}),
}
}
def execute(self, model: LanguageModel, prompt: str = None, images: list[torch.Tensor] | torch.Tensor = None, chat_template: str = _AUTO_CHAT_TEMPLATE) -> ValidatedNodeResult:
return super().execute(model, COSMOS_VIDEO_TO_WORLD_UPSAMPLE_TASK, images, _AUTO_CHAT_TEMPLATE)
export_custom_nodes()

View File

@@ -8,6 +8,8 @@ import torch._inductor.codecache
from torch.nn import LayerNorm
from comfy import model_management
from comfy.language.language_types import LanguageModel
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
@@ -97,7 +99,7 @@ class TorchCompileModel(CustomNode):
}
move_to_gpu = True
del compile_kwargs["mode"]
if isinstance(model, ModelPatcher):
if isinstance(model, ModelPatcher) or isinstance(model, TransformersManagedModel):
m = model.clone()
if move_to_gpu:
model_management.load_models_gpu([m])
@@ -115,12 +117,19 @@
else:
logging.warning("Encountered a model that cannot be compiled")
return model,
except OSError:
except OSError as os_error:
try:
torch._inductor.utils.clear_inductor_caches() # pylint: disable=no-member
except Exception:
pass
raise
raise os_error
except Exception as exc_info:
try:
torch._inductor.utils.clear_inductor_caches() # pylint: disable=no-member
except Exception:
pass
logging.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
return model,
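The behavior above boils down to a compile-or-passthrough pattern: attempt compilation, clear the inductor caches on failure, and fall back to the unmodified model. A generic sketch of that pattern (illustrative, not the node's code):

```python
import logging
import torch

def compile_or_passthrough(module: torch.nn.Module) -> torch.nn.Module:
    try:
        return torch.compile(module)
    except Exception as exc_info:
        try:
            # best-effort: this helper may not exist on every torch version
            import torch._inductor.utils
            torch._inductor.utils.clear_inductor_caches()  # pylint: disable=no-member
        except Exception:
            pass
        logging.error("compilation failed, falling back to the uncompiled module", exc_info=exc_info)
        return module
```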
_QUANTIZATION_STRATEGIES = [

Binary file not shown. (new image, 2.8 MiB)

Binary file not shown. (new image, 5.8 MiB)