Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-22 20:30:25 +08:00
Improving MLLM/VLLM support and fixing bugs
- fix #29: str(model) no longer raises exceptions, as it did with HyVideoModelLoader
- don't try to format CUDA tensors, because that can sometimes raise exceptions
- cudaMallocAsync has been disabled for now due to 2.6.0 bugs
- improve florence2 support
- add support for PaliGemma 2; this requires the fix for transformers that is currently staged in another repo, install with `uv pip install --no-deps "transformers@git+https://github.com/zucchini-nlp/transformers.git#branch=paligemma-fix-kwargs"`
- triton has been updated
- fix missing __init__.py files
parent dcac115f68
commit 6ab1aa1e8a
@@ -49,7 +49,7 @@ def _create_parser() -> EnhancedConfigArgParser:
    cm_group = parser.add_mutually_exclusive_group()
    cm_group.add_argument("--cuda-malloc", action="store_true",
                          help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
-    cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
+    cm_group.add_argument("--disable-cuda-malloc", action="store_true", default=True, help="Disable cudaMallocAsync.")

    fp_group = parser.add_mutually_exclusive_group()
    fp_group.add_argument("--force-fp32", action="store_true",
@@ -142,7 +142,7 @@ class Configuration(dict):
        self.disable_auto_launch: bool = False
        self.cuda_device: Optional[int] = None
        self.cuda_malloc: bool = True
-        self.disable_cuda_malloc: bool = False
+        self.disable_cuda_malloc: bool = True
        self.dont_upcast_attention: bool = False
        self.force_upcast_attention: bool = False
        self.force_fp32: bool = False
@@ -290,8 +290,10 @@ def format_value(x) -> FormattedValue:
        return None
    elif isinstance(x, (int, float, bool, str)):
        return x
-    else:
+    elif isinstance(x, dict) and not any(isinstance(v, torch.Tensor) for v in x.values()):
        return str(x)
+    else:
+        return str(x.__class__)


def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
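Note on the `format_value` change above: dictionaries are now only stringified when they contain no tensors, since calling `str()` on a CUDA tensor can itself raise. A minimal standalone sketch of the same guard (illustrative names, not the repository's code):

```python
import torch

def safe_format(x):
    # primitives pass through unchanged
    if x is None or isinstance(x, (int, float, bool, str)):
        return x
    # only stringify dicts that hold no tensors; formatting a CUDA tensor can raise
    if isinstance(x, dict) and not any(isinstance(v, torch.Tensor) for v in x.values()):
        return str(x)
    # anything else (including tensor-bearing dicts) is reduced to its class name
    return str(x.__class__)

print(safe_format({"scale": 1.0}))              # "{'scale': 1.0}"
print(safe_format({"latent": torch.zeros(1)}))  # "<class 'dict'>"
```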
@@ -63,7 +63,8 @@ except Exception:

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"


def _fix_pytorch_240():
    """Fixes pytorch 2.4.0"""
@@ -1,5 +1,6 @@
from __future__ import annotations

+import contextlib
import copy
import inspect
import logging
@@ -29,7 +30,8 @@ from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
_OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2']

# should be added if the expectation is that this model emits special tokens
-_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'}
+_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'}


class TransformersManagedModel(ModelManageable, LanguageModel):
    def __init__(
@@ -214,7 +216,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):

        text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)

-        with seed_for_block(seed):
+        try:
+            import triton  # pylint: disable=import-error
+            has_triton = True
+        except (ImportError, ModuleNotFoundError):
+            has_triton = False
+
+        with seed_for_block(seed), torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext():
            if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                inputs.pop("attention_mask")
            output_ids = transformers_model.generate(
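The generation hunk above picks its context manager at runtime: `torch.inference_mode` is only entered when triton imports cleanly, otherwise a no-op context is used. A rough standalone sketch of that pattern (hypothetical helper, not the project's API):

```python
import contextlib
import torch

def generation_context():
    # Prefer inference mode when triton is available; otherwise fall back to a
    # no-op context manager so generation still runs without the dependency.
    try:
        import triton  # noqa: F401
        return torch.inference_mode(mode=True)
    except (ImportError, ModuleNotFoundError):
        return contextlib.nullcontext()

# usage with a hypothetical transformers model:
# with generation_context():
#     output_ids = model.generate(**inputs)
```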
@@ -310,7 +318,6 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
    def model_dtype(self) -> torch.dtype:
        return self.model.dtype

-
    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
        return self.model.to(device=device_to)

@@ -344,11 +351,12 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
        if isinstance(images, list):
            images = torch.stack(images, dim=0)
        if images is not None:
-            # PIL it for the sake of simplicity
            image_sizes = [(image.shape[-2], image.shape[-3]) for image in images]
        else:
            image_sizes = []
-            images = []
+            # todo: what is the best choice for this?
+            # probably select a size related to the vision tower?
+            images = torch.zeros((0, 0, 0, 3))

        try:
            if hasattr(tokenizer, "apply_chat_template"):
@@ -383,8 +391,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
        else:
            if hasattr(self.processor, "to"):
                self.processor.to(device=self.load_device)
-            batch_feature: BatchFeature = self.processor(text=[prompt], images=images.unbind(), return_tensors="pt", padding=True)
+            # convert tuple to list from images.unbind() for paligemma workaround
+            batch_feature: BatchFeature = self.processor(text=[prompt], images=list(images.unbind()) if images is not None and len(images) > 0 else None, return_tensors="pt", padding=True)
            if hasattr(self.processor, "to"):
                self.processor.to(device=self.offload_device)
            assert "input_ids" in batch_feature
@@ -452,6 +452,10 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
    'roborovski/superprompt-v1',
    'Qwen/Qwen2-VL-7B-Instruct',
    'microsoft/Florence-2-large-ft',
+    'google/paligemma2-10b-pt-896',
+    'google/paligemma2-28b-pt-896',
+    'google/paligemma-3b-ft-refcoco-seg-896',
+    'microsoft/phi-4',
}

KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
@@ -551,7 +551,7 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
        _load_models_gpu(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
        to_load = list(map(str, models))
        span.set_attribute("models", to_load)
-        logger.info(f"Loaded {to_load}")
+        logger.debug(f"Loaded {to_load}")


def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
@@ -627,6 +627,8 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0

            loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
            current_loaded_models.insert(0, loaded_model)
+            logger.debug(f"Loaded {loaded_model}")
+

    span = get_current_span()
    span.set_attribute("models_to_load", list(map(str, models_to_load)))
@@ -265,9 +265,6 @@ class ModelPatcher(ModelManageable):
    def lowvram_patch_counter(self):
        return self._memory_measurements.lowvram_patch_counter

-        if not hasattr(self.model, 'current_weight_patches_uuid'):
-            self.model.current_weight_patches_uuid = None
-
    def model_size(self):
        if self.size > 0:
            return self.size
@@ -845,7 +842,10 @@ class ModelPatcher(ModelManageable):

    def __str__(self):
        if hasattr(self.model, "operations"):
-            operations_str = self.model.operations.__name__
+            if hasattr(self.model.operations, "__name__"):
+                operations_str = self.model.operations.__name__
+            else:
+                operations_str = str(self.model.operations)
        else:
            operations_str = None
        info_str = f"model_dtype={self.model_dtype()} device={self.model_device} size={naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)} operations={operations_str}"
@@ -63,6 +63,10 @@ NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]

InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]

+# numpy seeds must be between 0 and 2**32 - 1
+Seed = ("INT", {"default": 0, "min": 0, "max": 2**32 - 1})
+SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]
+

class HiddenSpec(TypedDict, total=True):
    prompt: Literal["PROMPT"]
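The nodes changed below all switch their seed inputs to this shared `Seed` tuple, so every seed widget is clamped to numpy's 32-bit range instead of the assorted 64-bit maxima used before. A hedged sketch of how a node consumes it (hypothetical node; the stand-in constant mirrors the import):

```python
import numpy as np

# stand-in for: from comfy.nodes.package_typing import Seed
Seed = ("INT", {"default": 0, "min": 0, "max": 2**32 - 1})

class ExampleSeededNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"seed": Seed}}

    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, seed: int):
        # numpy RNGs reject seeds outside [0, 2**32 - 1], hence the shared spec
        rng = np.random.default_rng(seed)
        return (int(rng.integers(0, 2**31)),)
```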
comfy_extras/chainner_models/__init__.py (new file, 0 lines)
comfy_extras/constants/__init__.py (new file, 0 lines)
@@ -23,6 +23,7 @@ import numpy as np
import torch

from comfy.nodes.common import MAX_RESOLUTION
+from comfy.nodes.package_typing import Seed
from comfy.utils import ProgressBar
import logging as log
# Sync with theoritical limit from Comfy base
@@ -73,7 +74,7 @@ class INPUT(Enum):
    def MASK():
        return ("MASK",)
    def SEED(default=0):
-        return ("INT", dict(default=default, min=0, max=0xffffffffffffffff))
+        return Seed
    def RESOLUTION(default=512, min=64, max=MAX_RESOLUTION, step=64):
        return ("INT", dict(default=default, min=min, max=max, step=step))
    def INT(default=0, min=0, max=MAX_RESOLUTION, step=1):
@@ -8,6 +8,7 @@ from comfy.cmd import latent_preview
import torch
from comfy import utils
from comfy import node_helpers
+from comfy.nodes.package_typing import Seed
from comfy.samplers import KSAMPLER


@@ -458,7 +459,7 @@ class SamplerCustom:
        return {"required":
                    {"model": ("MODEL",),
                     "add_noise": ("BOOLEAN", {"default": True}),
-                     "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                     "noise_seed": Seed,
                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                     "positive": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
@@ -610,7 +611,7 @@ class RandomNoise(DisableNoise):
    @classmethod
    def INPUT_TYPES(s):
        return {"required":{
-                    "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                    "noise_seed": Seed,
             }
        }

@@ -138,7 +138,7 @@ class Florence2PostProcess(CustomNode):
        return model.processor.post_process_generation(generated_text, task=task, image_size=(images.shape[-2], images.shape[-3])),


-class Florence2OutputToPolygon(CustomNode):
+class Florence2OutputToMask(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
@@ -166,6 +166,6 @@ NODE_CLASS_MAPPINGS = {}
for cls in (
        Florence2PostProcess,
        Florence2TaskTokenize,
-        Florence2OutputToPolygon
+        Florence2OutputToMask
):
    NODE_CLASS_MAPPINGS[cls.__name__] = cls
@@ -9,7 +9,7 @@ import torch
from PIL import Image

from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
-from comfy.nodes.package_typing import CustomNode
+from comfy.nodes.package_typing import CustomNode, Seed
from comfy.utils import pil2tensor, tensor2pil
from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
from comfy_extras.nodes.nodes_mask import MaskToImage
@@ -46,7 +46,7 @@ class IdeogramGenerate(CustomNode):
                "api_key": ("STRING", {"default": ""}),
                "negative_prompt": ("STRING", {"multiline": True}),
                "num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
-                "seed": ("INT", {"default": 0}),
+                "seed": Seed,
            }
        }

@@ -20,7 +20,7 @@ from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWA
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
-from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
+from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed

_AUTO_CHAT_TEMPLATE = "default"

@@ -339,7 +339,7 @@ class TransformersGenerate(CustomNode):
                "tokens": (TOKENS_TYPE_NAME, {}),
                "max_new_tokens": ("INT", {"default": 512, "min": 1}),
                "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
-                "seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
+                "seed": Seed,
            },
            "optional": {
                "sampler": (GENERATION_KWARGS_TYPE_NAME, {}),
@@ -3,6 +3,7 @@ import torch

import comfy.utils
from comfy.component_model.tensor_types import Latent
+from comfy.nodes.package_typing import Seed
from .nodes_post_processing import gaussian_kernel


@@ -168,7 +169,7 @@ class LatentAddNoiseChannels:
            "required": {
                "samples": ("LATENT",),
                "std_dev": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}),
-                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                "seed": Seed,
                "slice_i": ("INT", {"default": 0, "min": -16, "max": 16}),
                "slice_j": ("INT", {"default": 16, "min": -16, "max": 16}),
            }
comfy_extras/nodes/nodes_paligemma.py (new file, 250 lines)
@@ -0,0 +1,250 @@
import functools
import re
from importlib.resources import as_file, files
from typing import TypedDict, NamedTuple

import PIL.Image
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import torch
from jaxtyping import Float

from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
from comfy.nodes.package_typing import CustomNode, InputTypes

_MODEL_PATH = 'vae-oid.npz'

_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)

PALIGEMMA_OUTPUT_NAME = "PALIGEMMA_OUTPUT"


class BoundingBox(NamedTuple):
    x1: int
    y1: int
    x2: int
    y2: int


PaligemmaMask = Float[np.ndarray, "height width"]


class ExtractedPaligemmaSegmented(TypedDict):
    content: str
    xyxy: BoundingBox
    mask: PaligemmaMask | None
    name: str


class ExtractedPaligemmaContentOnly(TypedDict):
    content: str


ExtractedPaligemmaObject = ExtractedPaligemmaSegmented | ExtractedPaligemmaContentOnly
PostProcessResult = list[ExtractedPaligemmaObject]


def _get_params(checkpoint):
    """Converts PyTorch checkpoint to Flax params."""

    def transp(kernel):
        return np.transpose(kernel, (2, 3, 1, 0))

    def conv(name):
        return {
            'bias': checkpoint[name + '.bias'],
            'kernel': transp(checkpoint[name + '.weight']),
        }

    def resblock(name):
        return {
            'Conv_0': conv(name + '.0'),
            'Conv_1': conv(name + '.2'),
            'Conv_2': conv(name + '.4'),
        }

    return {
        '_embeddings': checkpoint['_vq_vae._embedding'],
        'Conv_0': conv('decoder.0'),
        'ResBlock_0': resblock('decoder.2.net'),
        'ResBlock_1': resblock('decoder.3.net'),
        'ConvTranspose_0': conv('decoder.4'),
        'ConvTranspose_1': conv('decoder.6'),
        'ConvTranspose_2': conv('decoder.8'),
        'ConvTranspose_3': conv('decoder.10'),
        'Conv_1': conv('decoder.12'),
    }


def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    unused_num_embeddings, embedding_dim = embeddings.shape

    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
    return encodings


@functools.cache
def _get_reconstruct_masks():
    """Reconstructs masks from codebook indices.
    Returns:
      A function that expects indices shaped `[B, 16]` of dtype int32, each
      ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
      `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
    """

    class ResBlock(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            original_x = x
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
            return x + original_x

    class Decoder(nn.Module):
        """Upscales quantized vectors to mask."""

        @nn.compact
        def __call__(self, x):
            num_res_blocks = 2
            dim = 128
            num_upsample_layers = 4

            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
            x = nn.relu(x)

            for _ in range(num_res_blocks):
                x = ResBlock(features=dim)(x)

            for _ in range(num_upsample_layers):
                x = nn.ConvTranspose(
                    features=dim,
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    padding=2,
                    transpose_kernel=True,
                )(x)
                x = nn.relu(x)
                dim //= 2

            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)

            return x

    def reconstruct_masks(codebook_indices):
        quantized = _quantized_values_from_codebook_indices(
            codebook_indices, params['_embeddings']
        )
        return Decoder().apply({'params': params}, quantized)

    with as_file(files("comfy_extras.paligemma") / _MODEL_PATH) as f:
        params = _get_params(dict(np.load(f)))

    return jax.jit(reconstruct_masks, backend='cpu')


def extract_objs(text, width, height, unique_labels=False) -> PostProcessResult:
    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
    objs: list[ExtractedPaligemmaObject] = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]

        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        seg_indices = gs[4:20]
        if seg_indices[0] is None:
            mask = None
        else:
            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
            mask = np.zeros([height, width])
            if y2 > y1 and x2 > x1:
                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0

        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        paligemma_output_obj: ExtractedPaligemmaObject = {'content': content, 'xyxy': BoundingBox(x1, y1, x2, y2), 'mask': mask, 'name': name}
        objs.append(paligemma_output_obj)
        text = text[len(before) + len(content):]

    if text:
        objs.append(dict(content=text))

    return [obj for obj in objs if obj["content"] != '<eos>']


class PaligemmaPostProcess(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "generated_text": ("STRING", {"forceInput": True}),
            },
            "optional": {
                "images": ("IMAGE", {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = (PALIGEMMA_OUTPUT_NAME,)
    RETURN_NAMES = ("paligemma output",)
    FUNCTION = "execute"

    def execute(self, generated_text: str = "", task: str = "", images: RGBImageBatch = None) -> tuple[PostProcessResult]:
        return extract_objs(generated_text, images.shape[-2], images.shape[-3]),


class PaligemmaOutputToMask(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "paligemma_output": (PALIGEMMA_OUTPUT_NAME, {"forceInput": True}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("paligemma output",)
    FUNCTION = "execute"

    def execute(self, paligemma_output: PostProcessResult) -> tuple[MaskBatch]:
        masks = [torch.from_numpy(p["mask"]) for p in paligemma_output if "mask" in p]
        if len(masks) == 0:
            return torch.zeros((0, 0, 0)),
        return torch.stack(masks, dim=0),


NODE_CLASS_MAPPINGS = {}
for cls in (
        PaligemmaOutputToMask,
        PaligemmaPostProcess,
):
    NODE_CLASS_MAPPINGS[cls.__name__] = cls
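For context on the new file above: PaliGemma emits `<locYYYY>` tokens that encode box corners on a 0-1023 grid (in y1, x1, y2, x2 order), optionally followed by `<segXXX>` codebook tokens and a label; `extract_objs` turns that text into pixel-space boxes and masks. A small self-contained illustration of just the box decoding (the example string and the 896x896 size are assumptions):

```python
import re

# a made-up PaliGemma-style detection string
text = "<loc0100><loc0200><loc0800><loc0900> cat"

# four <locYYYY> values on a 0..1023 grid, in y1, x1, y2, x2 order
y1, x1, y2, x2 = [int(v) / 1024 for v in re.findall(r"<loc(\d{4})>", text)]

width = height = 896  # assumed input resolution of the -896 checkpoints
box = tuple(round(v) for v in (x1 * width, y1 * height, x2 * width, y2 * height))
print("cat box (x1, y1, x2, y2):", box)
```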
comfy_extras/paligemma/__init__.py (new file, 0 lines)
comfy_extras/paligemma/vae-oid.npz (new binary file, not shown)
@@ -1,4 +1,5 @@
triton ;platform_system == 'Linux'
-triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12'
-triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'
-triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp313-cp313-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.13'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10'
@@ -70,4 +70,6 @@ pebble>=5.0.7
openai
anthropic
humanize
lightning
+flax
+jax
tests/issues/__test_29_fix_str_in_model.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import torch.nn

from comfy.model_patcher import ModelPatcher


class HasOperationsNoName(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.operations = object()
        if hasattr(self.operations, "__name__"):
            delattr(self.operations, "__name__")


def test_str_model_patcher():
    model_patcher = ModelPatcher(HasOperationsNoName(), torch.device('cpu'), torch.device('cpu'))
    assert str(model_patcher) is not None
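The regression test above can be run directly with pytest (assuming a development install): `pytest tests/issues/__test_29_fix_str_in_model.py`.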