Improving MLLM/VLLM support and fixing bugs

- fix #29: str(model) no longer raises exceptions, e.g. with HyVideoModelLoader
- don't try to format CUDA tensors, because that can sometimes raise exceptions
- cudaMallocAsync has been disabled by default for now due to torch 2.6.0 bugs; pass `--cuda-malloc` to re-enable it
- improve Florence-2 support
- add support for PaliGemma 2. This requires the transformers fix that is currently staged in another repo; install it with
  `uv pip install --no-deps "transformers@git+https://github.com/zucchini-nlp/transformers.git#branch=paligemma-fix-kwargs"`
- triton has been updated
- fix missing __init__.py files
doctorpangloss 2025-02-05 14:02:28 -08:00
parent dcac115f68
commit 6ab1aa1e8a
23 changed files with 323 additions and 30 deletions

View File

@@ -49,7 +49,7 @@ def _create_parser() -> EnhancedConfigArgParser:
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true",
help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", default=True, help="Disable cudaMallocAsync.")
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true",

View File

@@ -142,7 +142,7 @@ class Configuration(dict):
self.disable_auto_launch: bool = False
self.cuda_device: Optional[int] = None
self.cuda_malloc: bool = True
self.disable_cuda_malloc: bool = False
self.disable_cuda_malloc: bool = True
self.dont_upcast_attention: bool = False
self.force_upcast_attention: bool = False
self.force_fp32: bool = False

View File

@@ -290,8 +290,10 @@ def format_value(x) -> FormattedValue:
return None
elif isinstance(x, (int, float, bool, str)):
return x
else:
elif isinstance(x, dict) and not any(isinstance(v, torch.Tensor) for v in x.values()):
return str(x)
else:
return str(x.__class__)
def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
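For illustration, a standalone sketch of the new format_value branching (restated here, not imported from the module): dicts whose values include tensors now fall through to the class-name branch, so str() is never called on a CUDA tensor, which per the commit message can sometimes raise.

import torch

def format_value(x):
    # None and primitives pass through unchanged.
    if x is None or isinstance(x, (int, float, bool, str)):
        return x
    # Only stringify dicts whose values contain no tensors.
    elif isinstance(x, dict) and not any(isinstance(v, torch.Tensor) for v in x.values()):
        return str(x)
    # Everything else (tensors, tensor-bearing dicts, arbitrary objects) reports only its class.
    else:
        return str(x.__class__)

format_value({"steps": 20})                # -> "{'steps': 20}"
format_value({"samples": torch.zeros(1)})  # -> "<class 'dict'>", the tensor is never formatted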

View File

@@ -63,7 +63,8 @@ except Exception:
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
def _fix_pytorch_240():
"""Fixes pytorch 2.4.0"""

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import contextlib
import copy
import inspect
import logging
@@ -29,7 +30,8 @@ from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
_OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2']
# a model type should be added here when the model is expected to emit special tokens
_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'}
_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'}
class TransformersManagedModel(ModelManageable, LanguageModel):
def __init__(
@@ -214,7 +216,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
with seed_for_block(seed):
try:
import triton # pylint: disable=import-error
has_triton = True
except (ImportError, ModuleNotFoundError):
has_triton = False
with seed_for_block(seed), torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext():
if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
inputs.pop("attention_mask")
output_ids = transformers_model.generate(
@@ -310,7 +318,6 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
def model_dtype(self) -> torch.dtype:
return self.model.dtype
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
return self.model.to(device=device_to)
@@ -344,11 +351,12 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
if isinstance(images, list):
images = torch.stack(images, dim=0)
if images is not None:
# PIL it for the sake of simplicity
image_sizes = [(image.shape[-2], image.shape[-3]) for image in images]
else:
image_sizes = []
images = []
# todo: what is the best choice for this?
# probably select a size related to the vision tower?
images = torch.zeros((0, 0, 0, 3))
try:
if hasattr(tokenizer, "apply_chat_template"):
@@ -383,8 +391,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
else:
if hasattr(self.processor, "to"):
self.processor.to(device=self.load_device)
batch_feature: BatchFeature = self.processor(text=[prompt], images=images.unbind(), return_tensors="pt", padding=True)
# workaround for paligemma: convert the tuple returned by images.unbind() to a list, and pass None when there are no images
batch_feature: BatchFeature = self.processor(text=[prompt], images=list(images.unbind()) if images is not None and len(images) > 0 else None, return_tensors="pt", padding=True)
if hasattr(self.processor, "to"):
self.processor.to(device=self.offload_device)
assert "input_ids" in batch_feature

View File

@@ -452,6 +452,10 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
'roborovski/superprompt-v1',
'Qwen/Qwen2-VL-7B-Instruct',
'microsoft/Florence-2-large-ft',
'google/paligemma2-10b-pt-896',
'google/paligemma2-28b-pt-896',
'google/paligemma-3b-ft-refcoco-seg-896',
'microsoft/phi-4',
}
KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([

View File

@@ -551,7 +551,7 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
_load_models_gpu(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
to_load = list(map(str, models))
span.set_attribute("models", to_load)
logger.info(f"Loaded {to_load}")
logger.debug(f"Loaded {to_load}")
def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
@@ -627,6 +627,8 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
logger.debug(f"Loaded {loaded_model}")
span = get_current_span()
span.set_attribute("models_to_load", list(map(str, models_to_load)))

View File

@@ -265,9 +265,6 @@ class ModelPatcher(ModelManageable):
def lowvram_patch_counter(self):
return self._memory_measurements.lowvram_patch_counter
if not hasattr(self.model, 'current_weight_patches_uuid'):
self.model.current_weight_patches_uuid = None
def model_size(self):
if self.size > 0:
return self.size
@ -845,7 +842,10 @@ class ModelPatcher(ModelManageable):
def __str__(self):
if hasattr(self.model, "operations"):
if hasattr(self.model.operations, "__name__"):
operations_str = self.model.operations.__name__
else:
operations_str = str(self.model.operations)
else:
operations_str = None
info_str = f"model_dtype={self.model_dtype()} device={self.model_device} size={naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)} operations={operations_str}"

View File

@@ -63,6 +63,10 @@ NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
# numpy seeds must be between 0 and 2**32 - 1
Seed = ("INT", {"default": 0, "min": 0, "max": 2**32 - 1})
SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]
class HiddenSpec(TypedDict, total=True):
prompt: Literal["PROMPT"]

View File

View File

View File

@@ -23,6 +23,7 @@ import numpy as np
import torch
from comfy.nodes.common import MAX_RESOLUTION
from comfy.nodes.package_typing import Seed
from comfy.utils import ProgressBar
import logging as log
# Sync with theoretical limit from Comfy base
@@ -73,7 +74,7 @@ class INPUT(Enum):
def MASK():
return ("MASK",)
def SEED(default=0):
return ("INT", dict(default=default, min=0, max=0xffffffffffffffff))
return Seed
def RESOLUTION(default=512, min=64, max=MAX_RESOLUTION, step=64):
return ("INT", dict(default=default, min=min, max=max, step=step))
def INT(default=0, min=0, max=MAX_RESOLUTION, step=1):

View File

@@ -8,6 +8,7 @@ from comfy.cmd import latent_preview
import torch
from comfy import utils
from comfy import node_helpers
from comfy.nodes.package_typing import Seed
from comfy.samplers import KSAMPLER
@@ -458,7 +459,7 @@ class SamplerCustom:
return {"required":
{"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"noise_seed": Seed,
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
@@ -610,7 +611,7 @@ class RandomNoise(DisableNoise):
@classmethod
def INPUT_TYPES(s):
return {"required":{
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"noise_seed": Seed,
}
}

View File

@@ -138,7 +138,7 @@ class Florence2PostProcess(CustomNode):
return model.processor.post_process_generation(generated_text, task=task, image_size=(images.shape[-2], images.shape[-3])),
class Florence2OutputToPolygon(CustomNode):
class Florence2OutputToMask(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
@@ -166,6 +166,6 @@ NODE_CLASS_MAPPINGS = {}
for cls in (
Florence2PostProcess,
Florence2TaskTokenize,
Florence2OutputToPolygon
Florence2OutputToMask
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@@ -9,7 +9,7 @@ import torch
from PIL import Image
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
from comfy.nodes.package_typing import CustomNode
from comfy.nodes.package_typing import CustomNode, Seed
from comfy.utils import pil2tensor, tensor2pil
from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
from comfy_extras.nodes.nodes_mask import MaskToImage
@@ -46,7 +46,7 @@ class IdeogramGenerate(CustomNode):
"api_key": ("STRING", {"default": ""}),
"negative_prompt": ("STRING", {"multiline": True}),
"num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
"seed": ("INT", {"default": 0}),
"seed": Seed,
}
}

View File

@@ -20,7 +20,7 @@ from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWA
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed
_AUTO_CHAT_TEMPLATE = "default"
@@ -339,7 +339,7 @@ class TransformersGenerate(CustomNode):
"tokens": (TOKENS_TYPE_NAME, {}),
"max_new_tokens": ("INT", {"default": 512, "min": 1}),
"repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
"seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
"seed": Seed,
},
"optional": {
"sampler": (GENERATION_KWARGS_TYPE_NAME, {}),

View File

@@ -3,6 +3,7 @@ import torch
import comfy.utils
from comfy.component_model.tensor_types import Latent
from comfy.nodes.package_typing import Seed
from .nodes_post_processing import gaussian_kernel
@@ -168,7 +169,7 @@ class LatentAddNoiseChannels:
"required": {
"samples": ("LATENT",),
"std_dev": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"seed": Seed,
"slice_i": ("INT", {"default": 0, "min": -16, "max": 16}),
"slice_j": ("INT", {"default": 16, "min": -16, "max": 16}),
}

View File

@@ -0,0 +1,250 @@
import functools
import re
from importlib.resources import as_file, files
from typing import TypedDict, NamedTuple
import PIL.Image
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import torch
from jaxtyping import Float
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
from comfy.nodes.package_typing import CustomNode, InputTypes
_MODEL_PATH = 'vae-oid.npz'
_SEGMENT_DETECT_RE = re.compile(
r'(.*?)' +
r'<loc(\d{4})>' * 4 + r'\s*' +
'(?:%s)?' % (r'<seg(\d{3})>' * 16) +
r'\s*([^;<>]+)? ?(?:; )?',
)
PALIGEMMA_OUTPUT_NAME = "PALIGEMMA_OUTPUT"
class BoundingBox(NamedTuple):
x1: int
y1: int
x2: int
y2: int
PaligemmaMask = Float[np.ndarray, "height width"]
class ExtractedPaligemmaSegmented(TypedDict):
content: str
xyxy: BoundingBox
mask: PaligemmaMask | None
name: str
class ExtractedPaligemmaContentOnly(TypedDict):
content: str
ExtractedPaligemmaObject = ExtractedPaligemmaSegmented | ExtractedPaligemmaContentOnly
PostProcessResult = list[ExtractedPaligemmaObject]
def _get_params(checkpoint):
"""Converts PyTorch checkpoint to Flax params."""
def transp(kernel):
return np.transpose(kernel, (2, 3, 1, 0))
def conv(name):
return {
'bias': checkpoint[name + '.bias'],
'kernel': transp(checkpoint[name + '.weight']),
}
def resblock(name):
return {
'Conv_0': conv(name + '.0'),
'Conv_1': conv(name + '.2'),
'Conv_2': conv(name + '.4'),
}
return {
'_embeddings': checkpoint['_vq_vae._embedding'],
'Conv_0': conv('decoder.0'),
'ResBlock_0': resblock('decoder.2.net'),
'ResBlock_1': resblock('decoder.3.net'),
'ConvTranspose_0': conv('decoder.4'),
'ConvTranspose_1': conv('decoder.6'),
'ConvTranspose_2': conv('decoder.8'),
'ConvTranspose_3': conv('decoder.10'),
'Conv_1': conv('decoder.12'),
}
def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
batch_size, num_tokens = codebook_indices.shape
assert num_tokens == 16, codebook_indices.shape
unused_num_embeddings, embedding_dim = embeddings.shape
encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
return encodings
@functools.cache
def _get_reconstruct_masks():
"""Reconstructs masks from codebook indices.
Returns:
A function that expects indices shaped `[B, 16]` of dtype int32, each
ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
`[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
"""
class ResBlock(nn.Module):
features: int
@nn.compact
def __call__(self, x):
original_x = x
x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
x = nn.relu(x)
x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
x = nn.relu(x)
x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
return x + original_x
class Decoder(nn.Module):
"""Upscales quantized vectors to mask."""
@nn.compact
def __call__(self, x):
num_res_blocks = 2
dim = 128
num_upsample_layers = 4
x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
x = nn.relu(x)
for _ in range(num_res_blocks):
x = ResBlock(features=dim)(x)
for _ in range(num_upsample_layers):
x = nn.ConvTranspose(
features=dim,
kernel_size=(4, 4),
strides=(2, 2),
padding=2,
transpose_kernel=True,
)(x)
x = nn.relu(x)
dim //= 2
x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
return x
def reconstruct_masks(codebook_indices):
quantized = _quantized_values_from_codebook_indices(
codebook_indices, params['_embeddings']
)
return Decoder().apply({'params': params}, quantized)
with as_file(files("comfy_extras.paligemma") / _MODEL_PATH) as f:
params = _get_params(dict(np.load(f)))
return jax.jit(reconstruct_masks, backend='cpu')
def extract_objs(text, width, height, unique_labels=False) -> PostProcessResult:
"""Returns objs for a string with "<loc>" and "<seg>" tokens."""
objs: list[ExtractedPaligemmaObject] = []
seen = set()
while text:
m = _SEGMENT_DETECT_RE.match(text)
if not m:
break
gs = list(m.groups())
before = gs.pop(0)
name = gs.pop()
y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
seg_indices = gs[4:20]
if seg_indices[0] is None:
mask = None
else:
seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
mask = np.zeros([height, width])
if y2 > y1 and x2 > x1:
mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
content = m.group()
if before:
objs.append(dict(content=before))
content = content[len(before):]
while unique_labels and name in seen:
name = (name or '') + "'"
seen.add(name)
paligemma_output_obj: ExtractedPaligemmaObject = {'content': content, 'xyxy': BoundingBox(x1, y1, x2, y2), 'mask': mask, 'name': name}
objs.append(paligemma_output_obj)
text = text[len(before) + len(content):]
if text:
objs.append(dict(content=text))
return [obj for obj in objs if obj["content"] != '<eos>']
class PaligemmaPostProcess(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"generated_text": ("STRING", {"forceInput": True}),
},
"optional": {
"images": ("IMAGE", {}),
}
}
CATEGORY = "language"
RETURN_TYPES = (PALIGEMMA_OUTPUT_NAME,)
RETURN_NAMES = ("paligemma output",)
FUNCTION = "execute"
def execute(self, generated_text: str = "", task: str = "", images: RGBImageBatch = None) -> tuple[PostProcessResult]:
return extract_objs(generated_text, images.shape[-2], images.shape[-3]),
class PaligemmaOutputToMask(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"paligemma_output": (PALIGEMMA_OUTPUT_NAME, {"forceInput": True}),
},
}
CATEGORY = "language"
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("paligemma output",)
FUNCTION = "execute"
def execute(self, paligemma_output: PostProcessResult) -> tuple[MaskBatch]:
masks = [torch.from_numpy(p["mask"]) for p in paligemma_output if "mask" in p and p["mask"] is not None]
if len(masks) == 0:
return torch.zeros((0, 0, 0)),
return torch.stack(masks, dim=0),
NODE_CLASS_MAPPINGS = {}
for cls in (
PaligemmaOutputToMask,
PaligemmaPostProcess,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls
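A hedged usage sketch of extract_objs from the new file above; the detection string and sizes are illustrative, and the import path is assumed since the file name is not shown in this view.

# e.g. from comfy_extras.nodes.nodes_paligemma import extract_objs  # path assumed
objs = extract_objs("<loc0204><loc0102><loc0808><loc0908> cat", 1024, 1024)
# With width == height == 1024 the four <locNNNN> tokens map directly to pixel
# coordinates (tokens are ordered y1, x1, y2, x2 and scaled by /1024), giving:
# [{'content': '<loc0204><loc0102><loc0808><loc0908> cat',
#   'xyxy': BoundingBox(x1=102, y1=204, x2=908, y2=808),
#   'mask': None,  # no <seg...> tokens, so no 64x64 mask is decoded
#   'name': 'cat'}]
# When <seg...> tokens are present, mask decoding additionally requires jax/flax
# and the bundled vae-oid.npz checkpoint.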

View File

Binary file not shown.

View File

@@ -1,4 +1,5 @@
triton ;platform_system == 'Linux'
triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12'
triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'
triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10'
triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp313-cp313-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.13'
triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12'
triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'
triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10'

View File

@@ -71,3 +71,5 @@ openai
anthropic
humanize
lightning
flax
jax

View File

@@ -0,0 +1,16 @@
import torch.nn
from comfy.model_patcher import ModelPatcher
class HasOperationsNoName(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.operations = object()
if hasattr(self.operations, "__name__"):
delattr(self.operations, "__name__")
def test_str_model_patcher():
model_patcher = ModelPatcher(HasOperationsNoName(), torch.device('cpu'), torch.device('cpu'))
assert str(model_patcher) is not None