mirror of https://github.com/comfyanonymous/ComfyUI.git
Prompt upsampling, better torch.compile support for language models
This commit is contained in:
parent c6111fae7d · commit d82261485f

22 README.md
@@ -479,6 +479,28 @@ comfyui --use-sage-attention

**With PyTorch Attention**

## Cosmos Prompt Upsampling

The Cosmos prompt "upsampler," a fine-tune of Mistral-NeMo-12B, rewrites Cosmos prompts in the narrative style that NVIDIA's captioner used for the Cosmos training data, significantly improving generation results.

Here is a comparison between a simple prompt and its "upsampled" version.



**A dog is playing with a ball.**



**In a sun-drenched park, a playful golden retriever bounds joyfully across the lush green grass, its tail wagging with excitement. The dog, adorned with a vibrant red collar, is captivated by a bright yellow ball, which it cradles gently in its mouth. The camera captures the dog's animated expressions, from eager anticipation to sheer delight, as it trots and leaps, showcasing its agility and enthusiasm. The scene is bathed in warm, golden-hour light, enhancing the vibrant colors of the dog's fur and the ball. The background features a serene tree line, framing the playful interaction and creating a tranquil atmosphere. The static camera angle allows for an intimate focus on the dog's joyful antics, inviting viewers to share in this heartwarming moment of pure canine happiness.**

To use the Cosmos upsampler, install the prerequisites:

```shell
uv pip install loguru pynvml
uv pip install --no-deps git+https://github.com/NVIDIA/Cosmos.git
```

Then, load the workflow embedded in the example above by dragging and dropping the upsampled animation into your workspace.

The Cosmos upsampler ought to improve any text-to-video generation pipeline. Use the `Video2World` upsampler nodes to download Pixtral-12B and upsample prompts for an image-to-video workflow using NVIDIA's default prompt. Since Pixtral is not fine-tuned for this task, the improvement may not be significant over using another LLM.

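For reference, a minimal sketch of driving these nodes directly from Python rather than the graph UI (the node classes come from `comfy_extras/nodes/nodes_cosmos_upsampling.py`, added in this commit; the exact `execute` return shapes and the `reference_images` batch are assumptions):

```python
from comfy_extras.nodes.nodes_cosmos_upsampling import (
    CosmosPromptUpsamplerTransformersLoader,
    CosmosVideo2WorldUpsamplePromptTokenize,
)

# Downloads unsloth/Pixtral-12B-2409 on first use and wraps it as a managed model.
(pixtral,) = CosmosPromptUpsamplerTransformersLoader().execute(ckpt_name="unsloth/Pixtral-12B-2409")

# Tokenizes NVIDIA's default Video2World upsampling instruction against the
# reference image(s); feed the result to your usual language-generation node.
# reference_images: an IMAGE batch from elsewhere in the workflow (illustrative).
(tokens,) = CosmosVideo2WorldUpsamplePromptTokenize().execute(pixtral, images=reference_images)
```
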
# Custom Nodes

Custom Nodes can be added to ComfyUI by copying and pasting Python files into your `./custom_nodes` directory.

19 comfy/component_model/empty_init.py Normal file

@@ -0,0 +1,19 @@
import torch
import torch.overrides
import torch.utils._device


class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    def __init__(self, device=None):
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if getattr(func, '__module__', None) == 'torch.nn.init':
            if 'tensor' in kwargs:
                return kwargs['tensor']
            else:
                return args[0]
        if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)
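This is the familiar `TorchFunctionMode` trick for fast model materialization (the same pattern used in projects like lit-gpt): every `torch.nn.init.*` call becomes a no-op, and tensor constructors that omit `device=` are redirected to the chosen device. A minimal usage sketch (the `Linear` module and checkpoint path are illustrative, not from this commit):

```python
import torch

from comfy.component_model.empty_init import EmptyInitOnDevice

# Materialize a module without running its random-weight initializers; the
# parameters are allocated (uninitialized) directly on the chosen device.
with EmptyInitOnDevice(device=torch.device("cpu")):
    layer = torch.nn.Linear(8192, 8192)  # torch.nn.init calls are skipped

# The uninitialized weights are then overwritten by a real checkpoint.
layer.load_state_dict(torch.load("weights.pt"))  # hypothetical path
```
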
comfy/language/transformers_model_management.py

@@ -6,6 +6,7 @@ import inspect
import logging
import operator
import pathlib
+import weakref
from functools import reduce
from typing import Optional, Any, Callable

@@ -43,9 +44,10 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
                 processor: Optional[ProcessorMixin | AutoProcessor] = None
                 ):
        self._repo_id = repo_id
-       self.model = model
+       self._model = model
        self._tokenizer = tokenizer
        self._processor = processor
+       self._object_patches: dict[str, Any] = {}
        self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
        self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
        self.load_device = get_torch_device()

@@ -53,6 +55,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
        self._config_dict = config_dict
        self._on_set_processor(self._processor)
        self._model_type = ""
+       self._original_transformers_managed_model: weakref.ReferenceType["TransformersManagedModel"] = weakref.ref(self)
        if model.device != self.offload_device:
            model.to(device=self.offload_device)

@@ -426,6 +429,28 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
        else:
            return f"<TransformersManagedModel for {self.model.__class__.__name__}>"

+   def clone(self) -> TransformersManagedModel:
+       m = copy.copy(self)
+       # deep copy a few objects
+       m._object_patches = copy.copy(self._object_patches)
+       return m
+
+   def add_object_patch(self, name: str, obj: Any):
+       # for the sake of compatibility, rewrite the name to the actual model field
+       if name == "diffusion_model":
+           name = "model"
+
+       self._object_patches[name] = obj
+
+   def get_model_object(self, name: str) -> torch.nn.Module:
+       if name == "diffusion_model":
+           name = "model"
+       return super().get_model_object(name)
+
+   @property
+   def model(self) -> PreTrainedModel | torch.nn.Module:
+       return self._object_patches.get("model", self._model)


class _ProgressTextStreamer(TextStreamer):
    def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):

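These additions mirror `ModelPatcher`'s object-patch API so generic code (like the `TorchCompileModel` node below) can treat a language model the same way it treats a diffusion model. A minimal sketch of the mechanics (the `torch.compile` call and the `lm` variable are illustrative):

```python
import torch

# lm: an existing TransformersManagedModel instance (assumed)
m = lm.clone()  # shallow copy; the patch dict itself is copied
m.add_object_patch("diffusion_model", torch.compile(m.model))

# "diffusion_model" was rewritten to "model", so the `model` property now
# serves the compiled module from _object_patches while lm stays untouched.
assert m.model is not lm.model
```
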
@@ -406,7 +406,7 @@ class LoadedModel:
        self._set_model(model)

    @property
-   def model(self):
+   def model(self) -> ModelManageable:
        return self._model()

    def model_memory(self):

comfy/model_management_types.py

@@ -1,7 +1,7 @@
from __future__ import annotations

import dataclasses
-from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any
+from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any, NamedTuple

import torch
import torch.nn

@@ -150,6 +150,8 @@ class MemoryMeasurements:
    def device(self) -> torch.device:
        if isinstance(self.model, DeviceSettable):
            return self.model.device
+       elif hasattr(self.model, "device"):
+           return self.model.device
        else:
            return self._device

@@ -157,6 +159,8 @@ class MemoryMeasurements:
    def device(self, value: torch.device):
        if isinstance(self.model, DeviceSettable):
            self.model.device = value
+       elif hasattr(self.model, "to"):
+           self.model.to(value)
        self._device = value

@@ -175,3 +179,9 @@ class ModelOptions(TypedDict, total=False):
    disable_cfg1_optimization: NotRequired[bool]
    denoise_mask_function: NotRequired[Callable]
    patches: NotRequired[dict[str, list]]
+
+
+class LoadingListItem(NamedTuple):
+    module_size: int
+    name: str
+    module: torch.nn.Module
+    params: list[str]
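Since `NamedTuple` instances compare exactly like plain tuples, the existing call sites that sort the load list keep working unchanged, now ordering by the named `module_size` field first. A quick illustration (the values are made up):

```python
import torch

from comfy.model_management_types import LoadingListItem

items = [
    LoadingListItem(10, "small.block", torch.nn.Identity(), []),
    LoadingListItem(99, "big.block", torch.nn.Identity(), []),
]
items.sort(reverse=True)                # same call as in ModelPatcher.load()
assert items[0].module_size == 99       # largest module first
assert items[0] == (99, "big.block", items[0].module, [])  # still a tuple
```
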
comfy/model_patcher.py

@@ -39,7 +39,7 @@ from .float import stochastic_rounding
from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
from .model_base import BaseModel
-from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT
+from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem
from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection

logger = logging.getLogger(__name__)

@@ -627,7 +627,8 @@ class ModelPatcher(ModelManageable):
        else:
            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

-   def _load_list(self):
+   def _load_list(self) -> list[LoadingListItem]:
        loading = []
        for n, m in self.model.named_modules():
            params = []

@@ -639,7 +640,7 @@ class ModelPatcher(ModelManageable):
                    skip = True  # skip random weights in non leaf modules
                    break
            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
-               loading.append((model_management.module_size(m), n, m, params))
+               loading.append(LoadingListItem(model_management.module_size(m), n, m, params))
        return loading

    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):

@@ -650,13 +651,13 @@ class ModelPatcher(ModelManageable):
        lowvram_counter = 0
        loading = self._load_list()

-       load_completely = []
+       load_completely: list[LoadingListItem] = []
        loading.sort(reverse=True)
        for x in loading:
-           n = x[1]
-           m = x[2]
-           params = x[3]
-           module_mem = x[0]
+           n = x.name
+           m = x.module
+           params = x.params
+           module_mem = x.module_size

            lowvram_weight = False

@@ -696,7 +697,7 @@ class ModelPatcher(ModelManageable):

            if full_load or mem_counter + module_mem < lowvram_model_memory:
                mem_counter += module_mem
-               load_completely.append((module_mem, n, m, params))
+               load_completely.append(LoadingListItem(module_mem, n, m, params))

            if cast_weight:
                m.prev_comfy_cast_weights = m.comfy_cast_weights

@@ -712,9 +713,9 @@ class ModelPatcher(ModelManageable):

        load_completely.sort(reverse=True)
        for x in load_completely:
-           n = x[1]
-           m = x[2]
-           params = x[3]
+           n = x.name
+           m = x.module
+           params = x.params
            if hasattr(m, "comfy_patched_weights"):
                if m.comfy_patched_weights == True:
                    continue

@@ -726,7 +727,7 @@ class ModelPatcher(ModelManageable):
            m.comfy_patched_weights = True

        for x in load_completely:
-           x[2].to(device_to)
+           x.module.to(device_to)

        if lowvram_counter > 0:
            logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))

@@ -791,6 +792,8 @@ class ModelPatcher(ModelManageable):
        self.backup.clear()

        if device_to is not None:
+           if hasattr(self.model, "to"):
+               # todo: is this now redundant with self.model.to?
+               self.model.to(device_to)
            self.model_device = device_to
        self._memory_measurements.model_loaded_weight_memory = 0

comfy_extras/nodes/nodes_cosmos.py

@@ -2,11 +2,9 @@ import torch

import comfy.model_management
import comfy.utils
-from comfy.language.language_types import LanguageModel
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.common import MAX_RESOLUTION
-from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
-from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize, _AUTO_CHAT_TEMPLATE
+from comfy.nodes.package_typing import CustomNode


class EmptyCosmosLatentVideo(CustomNode):

@@ -81,32 +79,4 @@ class CosmosImageToVideoLatent(CustomNode):
        return (out_latent,)


-class CosmosPromptUpsamplerTransformersLoader(TransformersLoader):
-    @classmethod
-    def INPUT_TYPES(cls) -> InputTypes:
-        return {
-            "required": {
-                "ckpt_name": ("STRING", {}),
-            },
-        }
-
-
-# from https://github.com/NVIDIA/Cosmos/blob/b867572b99d08f450ddb8bcd6661d8c35bf6b967/cosmos1/models/diffusion/nemo/inference/inference_utils.py#L54
-FROM_COSMOS_REPO_PROMPT_PREFIX = "Upsample the short caption to a long caption: "
-
-
-class CosmosUpsamplePromptTokenize(OneShotInstructTokenize):
-    @classmethod
-    def INPUT_TYPES(cls) -> InputTypes:
-        return {
-            "required": {
-                "model": ("MODEL",),
-                "prompt": ("STRING", {"default": "", "multiline": True}),
-            },
-        }
-
-    def execute(self, model: LanguageModel, prompt: str, images: list[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
-        return super().execute(model, f"{FROM_COSMOS_REPO_PROMPT_PREFIX}{prompt}", images=None, chat_template=_AUTO_CHAT_TEMPLATE)


export_custom_nodes()

207 comfy_extras/nodes/nodes_cosmos_upsampling.py Normal file

@@ -0,0 +1,207 @@
import re
from pathlib import Path
from typing import Optional

import torch

from comfy import model_management
from comfy.component_model.empty_init import EmptyInitOnDevice
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.language.language_types import LanguageModel, ProcessorResult, LanguagePrompt, GENERATION_KWARGS_TYPE, \
    TOKENS_TYPE
from comfy.model_downloader import get_or_download_huggingface_repo
from comfy.model_management import load_models_gpu
from comfy.model_patcher import ModelPatcher
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import InputTypes, ValidatedNodeResult, CustomNode
from comfy_extras.nodes.nodes_language import TransformersLoader, TransformersTokenize, OneShotInstructTokenize, \
    _AUTO_CHAT_TEMPLATE

# from https://github.com/NVIDIA/Cosmos/blob/b867572b99d08f450ddb8bcd6661d8c35bf6b967/cosmos1/models/diffusion/nemo/inference/inference_utils.py#L54
COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK = "Upsample the short caption to a long caption: "
COSMOS_VIDEO_TO_WORLD_UPSAMPLE_TASK = """
Your task is to transform a given prompt into a refined and concise video description, no more than 150 words.
Focus only on the content, no filler words or descriptions on the style. Never mention things outside the video.
"""


def clean_text(text: str) -> str:
    """Clean the text by removing prefixes, suffixes, formatting markers, and normalizing whitespace."""
    # Replace all variations of newlines with a space
    text = text.replace("\n", " ").replace("\r", " ")

    # Use a regex to find sections of the form '- **...**'
    pattern = r"(- \*\*)(.*?)(\*\*)"

    def replacement(match: re.Match[str]) -> str:
        content = match.group(2)  # The text inside - ** and **
        words = re.findall(r"\w+", content)
        if len(words) < 10:
            # If fewer than 10 words, remove the entire '- **...**' portion
            return ""
        else:
            # If 10 or more words, keep the entire section as it is
            return match.group(0)

    text = re.sub(pattern, replacement, text)

    # Remove common prefixes
    prefixes = ["Caption:", "#####", "####", "- ", "* ", ","]
    for prefix in prefixes:
        # lstrip(prefix) won't strip entire strings, but character sets.
        # For more reliable prefix removal, do:
        if text.startswith(prefix):
            text = text[len(prefix):].lstrip()

    # Remove extra spaces
    text = " ".join(text.split())

    # Strip any remaining leading/trailing punctuation, whitespace, and quotes
    text = text.strip(' -,*:"\'"“”')

    return text


class CosmosPromptUpsamplerTransformersLoader(TransformersLoader):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": (["unsloth/Pixtral-12B-2409"], {}),
            },
        }


class Mistral12b(LanguageModel):
    def __init__(self, model: ModelPatcher, ckpt_name: str):
        self.model = model
        self.ckpt_name = ckpt_name

    @staticmethod
    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "Mistral12b":
        from cosmos1.models.autoregressive.configs.base.model_config import create_text_model_config  # pylint: disable=import-error
        from cosmos1.models.autoregressive.model import AutoRegressiveModel  # pylint: disable=import-error
        checkpoint_dir = get_or_download_huggingface_repo(ckpt_name)
        assert checkpoint_dir is not None, f"did not successfully download {ckpt_name}"
        checkpoint_dir = Path(checkpoint_dir)
        model_config, tokenizer_config = create_text_model_config(
            model_ckpt_path=str(checkpoint_dir / "model.pt"),
            tokenizer_path=str(checkpoint_dir),
            model_family="mistral",
            model_size="12b",
            is_instruct_model=True,
            max_batch_size=1,
            rope_dim="1D",
            add_special_tokens=True,
            max_seq_len=1024,
            pytorch_rope_version="v1",
        )

        try:
            with EmptyInitOnDevice(device=model_management.unet_offload_device()):
                completion_instance_cpu = AutoRegressiveModel.build(model_config=model_config, tokenizer_config=tokenizer_config)
        finally:
            torch.set_default_dtype(torch.float32)

        patchable_completion_instance_cpu = ModelPatcher(completion_instance_cpu, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), size=completion_instance_cpu.get_num_params() * 2, ckpt_name=ckpt_name)
        return Mistral12b(patchable_completion_instance_cpu, ckpt_name=ckpt_name)

    def generate(self,
                 tokens: TOKENS_TYPE = None,
                 max_new_tokens: int = 512,
                 repetition_penalty: float = 0.0,
                 seed: int = 0,
                 sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                 *args,
                 **kwargs) -> str:
        sampler = sampler or {}
        prompt = tokens.get("inputs", [])
        prompt = "".join(prompt)

        dialogs = [[{"role": "user", "content": prompt}]]

        from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion  # pylint: disable=import-error
        from cosmos1.models.autoregressive.model import AutoRegressiveModel  # pylint: disable=import-error

        load_models_gpu([self.model])

        # noinspection PyTypeChecker
        model: AutoRegressiveModel = self.model.model
        assert isinstance(model, AutoRegressiveModel)

        results = chat_completion(
            model,
            dialogs,
            seed=seed,
            max_gen_len=max_new_tokens,
            temperature=sampler.get("temperature", 0.01),
            top_p=sampler.get("top_p", None),
            top_k=sampler.get("top_k", None),
            logprobs=False,
        )

        upsampled_prompt = str(clean_text(results[0]["generation"]["content"]))
        return upsampled_prompt

    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
        # Return prompts and image as is
        return {
            "inputs": [prompt],
            "attention_mask": torch.ones(1, len(prompt)),  # Dummy attention mask
            "images": images
        }

    @property
    def repo_id(self) -> str:
        return self.ckpt_name


class CosmosText2WorldPromptUpsamplerLoader(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": (["nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World"], {}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("upsampler model",)
    CATEGORY = "cosmos"
    FUNCTION = "execute"

    def execute(self, ckpt_name: str) -> tuple[LanguageModel]:
        return Mistral12b.from_pretrained(ckpt_name),


class CosmosText2WorldUpsamplePromptTokenize(TransformersTokenize):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
            },
        }

    def execute(self, model: LanguageModel, prompt: str) -> ValidatedNodeResult:
        return super().execute(model, f"{COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK}{prompt}")


class CosmosVideo2WorldUpsamplePromptTokenize(OneShotInstructTokenize):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL", {}),
            },
            "optional": {
                "images": ("IMAGE", {}),
            }
        }

    def execute(self, model: LanguageModel, prompt: str = None, images: list[torch.Tensor] | torch.Tensor = None, chat_template: str = _AUTO_CHAT_TEMPLATE) -> ValidatedNodeResult:
        return super().execute(model, COSMOS_VIDEO_TO_WORLD_UPSAMPLE_TASK, images, _AUTO_CHAT_TEMPLATE)


export_custom_nodes()
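Putting the pieces together, a minimal sketch of using the Text2World upsampler outside the node graph (the prompt is illustrative; `from_pretrained` downloads the checkpoint on first use, and `generate` runs `clean_text` over the model output):

```python
from comfy_extras.nodes.nodes_cosmos_upsampling import (
    COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK,
    Mistral12b,
)

upsampler = Mistral12b.from_pretrained("nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World")

# tokenize() passes the string through as-is, so prepend the task prefix just
# like CosmosText2WorldUpsamplePromptTokenize does.
tokens = upsampler.tokenize(
    f"{COSMOS_TEXT_TO_WORLD_UPSAMPLE_TASK}A dog is playing with a ball.",
    images=None,
)
print(upsampler.generate(tokens=tokens, max_new_tokens=512))
```
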
@@ -8,6 +8,8 @@ import torch._inductor.codecache
from torch.nn import LayerNorm

from comfy import model_management
+from comfy.language.language_types import LanguageModel
+from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes

@@ -97,7 +99,7 @@ class TorchCompileModel(CustomNode):
            }
            move_to_gpu = True
            del compile_kwargs["mode"]
-       if isinstance(model, ModelPatcher):
+       if isinstance(model, ModelPatcher) or isinstance(model, TransformersManagedModel):
            m = model.clone()
            if move_to_gpu:
                model_management.load_models_gpu([m])

@@ -115,12 +117,19 @@ class TorchCompileModel(CustomNode):
            else:
                logging.warning("Encountered a model that cannot be compiled")
            return model,
-       except OSError:
+       except OSError as os_error:
            try:
                torch._inductor.utils.clear_inductor_caches()  # pylint: disable=no-member
            except Exception:
                pass
-           raise
+           raise os_error
+       except Exception as exc_info:
+           try:
+               torch._inductor.utils.clear_inductor_caches()  # pylint: disable=no-member
+           except Exception:
+               pass
+           logging.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
+           return model,


_QUANTIZATION_STRATEGIES = [

BIN docs/assets/prompt_upsampling_01.webp Normal file
Binary file not shown. After Width: | Height: | Size: 2.8 MiB

BIN docs/assets/prompt_upsampling_02.webp Normal file
Binary file not shown. After Width: | Height: | Size: 5.8 MiB