Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-09 13:50:49 +08:00)
Improving MLLM/VLLM support and fixing bugs
- fix #29: str(model) no longer raises exceptions, as it did with HyVideoModelLoader
- don't try to format CUDA tensors, because that can sometimes raise exceptions
- cudaMallocAsync has been disabled for now due to torch 2.6.0 bugs
- improve florence2 support
- add support for paligemma 2; this requires a fix for transformers that is currently staged in another repo, install with `uv pip install --no-deps "transformers@git+https://github.com/zucchini-nlp/transformers.git#branch=paligemma-fix-kwargs"`
- triton has been updated
- fix missing __init__.py files
parent dcac115f68
commit 6ab1aa1e8a
@@ -49,7 +49,7 @@ def _create_parser() -> EnhancedConfigArgParser:
     cm_group = parser.add_mutually_exclusive_group()
     cm_group.add_argument("--cuda-malloc", action="store_true",
                           help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
-    cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
+    cm_group.add_argument("--disable-cuda-malloc", action="store_true", default=True, help="Disable cudaMallocAsync.")
 
     fp_group = parser.add_mutually_exclusive_group()
     fp_group.add_argument("--force-fp32", action="store_true",
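For context on the flag flip above, a minimal standalone sketch (not the project's actual parser) of how `default=True` on a `store_true` argument changes the effective default:

```python
# Standalone sketch, assuming nothing beyond argparse: with default=True,
# disable_cuda_malloc is True unless downstream code honors an explicit
# --cuda-malloc opt-in.
import argparse

parser = argparse.ArgumentParser()
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", default=True)

print(parser.parse_args([]).disable_cuda_malloc)         # True: disabled by default now
print(parser.parse_args(["--cuda-malloc"]).cuda_malloc)  # True: explicit opt-in
# Note: disable_cuda_malloc stays True even with --cuda-malloc, since store_true
# never lowers a default; the consuming config presumably gives the explicit
# --cuda-malloc flag precedence.
```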
@@ -142,7 +142,7 @@ class Configuration(dict):
         self.disable_auto_launch: bool = False
         self.cuda_device: Optional[int] = None
         self.cuda_malloc: bool = True
-        self.disable_cuda_malloc: bool = False
+        self.disable_cuda_malloc: bool = True
         self.dont_upcast_attention: bool = False
         self.force_upcast_attention: bool = False
         self.force_fp32: bool = False
@@ -290,8 +290,10 @@ def format_value(x) -> FormattedValue:
         return None
     elif isinstance(x, (int, float, bool, str)):
         return x
-    else:
+    elif isinstance(x, dict) and not any(isinstance(v, torch.Tensor) for v in x.values()):
         return str(x)
+    else:
+        return str(x.__class__)
 
 
 def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
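A self-contained sketch of the patched behavior (reimplemented here for illustration, not imported from the project): `str()` on a dict holding CUDA tensors stringifies the tensors, which copies values to the host, synchronizes the device, and can surface pending CUDA errors as exceptions, so such dicts are reduced to their class name instead:

```python
import torch

def format_value_sketch(x):
    # Mirrors the hunk above: plain dicts still stringify, dicts holding
    # tensors fall through to the class-name branch.
    if x is None:
        return None
    elif isinstance(x, (int, float, bool, str)):
        return x
    elif isinstance(x, dict) and not any(isinstance(v, torch.Tensor) for v in x.values()):
        return str(x)
    else:
        return str(x.__class__)

print(format_value_sketch({"a": 1}))               # {'a': 1}
print(format_value_sketch({"t": torch.zeros(2)}))  # <class 'dict'>
```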
@@ -63,7 +63,8 @@ except Exception:
 
 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
+os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
 
 def _fix_pytorch_240():
     """Fixes pytorch 2.4.0"""
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import copy
 import inspect
 import logging
@@ -29,7 +30,8 @@ from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
 _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2']
 
-_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'}
+# should be added if the expectation is that this model emits special tokens
+_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'}
 
 
 class TransformersManagedModel(ModelManageable, LanguageModel):
     def __init__(
@@ -214,7 +216,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
 
         text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
 
-        with seed_for_block(seed):
+        try:
+            import triton  # pylint: disable=import-error
+            has_triton = True
+        except (ImportError, ModuleNotFoundError):
+            has_triton = False
+
+        with seed_for_block(seed), torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext():
             if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                 inputs.pop("attention_mask")
             output_ids = transformers_model.generate(
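The hunk above picks a context manager at runtime; a minimal sketch of that pattern in isolation (assumed standalone, not the project's code):

```python
# Choose a context manager conditionally, falling back to a no-op when an
# optional dependency is missing; a conditional expression is itself a valid
# with-item, so no if/else duplication of the body is needed.
import contextlib
import torch

try:
    import triton  # noqa: F401
    has_triton = True
except (ImportError, ModuleNotFoundError):
    has_triton = False

ctx = torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext()
with ctx:
    pass  # generation would run here
```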
@@ -310,7 +318,6 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
     def model_dtype(self) -> torch.dtype:
         return self.model.dtype
 
-
     def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
         return self.model.to(device=device_to)
@@ -344,11 +351,12 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
         if isinstance(images, list):
             images = torch.stack(images, dim=0)
         if images is not None:
             # PIL it for the sake of simplicity
             image_sizes = [(image.shape[-2], image.shape[-3]) for image in images]
         else:
             image_sizes = []
-            images = []
+            # todo: what is the best choice for this?
+            # probably select a size that relates to the vision tower?
+            images = torch.zeros((0, 0, 0, 3))
 
         try:
             if hasattr(tokenizer, "apply_chat_template"):
@@ -383,8 +391,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
         else:
             if hasattr(self.processor, "to"):
                 self.processor.to(device=self.load_device)
-
-            batch_feature: BatchFeature = self.processor(text=[prompt], images=images.unbind(), return_tensors="pt", padding=True)
+            # convert tuple to list from images.unbind() for paligemma workaround
+            batch_feature: BatchFeature = self.processor(text=[prompt], images=list(images.unbind()) if images is not None and len(images) > 0 else None, return_tensors="pt", padding=True)
             if hasattr(self.processor, "to"):
                 self.processor.to(device=self.offload_device)
             assert "input_ids" in batch_feature
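The tuple-to-list conversion above matters because `torch.unbind` returns a tuple, and, per the commit's comment, the PaliGemma processor expects a list of images or `None`. A minimal demonstration:

```python
import torch

# torch.unbind splits a batch along dim 0 and returns a tuple of views;
# wrapping it in list() satisfies processors that type-check for a list.
images = torch.rand(2, 224, 224, 3)
print(type(images.unbind()))        # <class 'tuple'>
print(type(list(images.unbind())))  # <class 'list'>
```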
@@ -452,6 +452,10 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
     'roborovski/superprompt-v1',
     'Qwen/Qwen2-VL-7B-Instruct',
     'microsoft/Florence-2-large-ft',
+    'google/paligemma2-10b-pt-896',
+    'google/paligemma2-28b-pt-896',
+    'google/paligemma-3b-ft-refcoco-seg-896',
+    'microsoft/phi-4',
 }
 
 KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
@@ -551,7 +551,7 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
         _load_models_gpu(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
         to_load = list(map(str, models))
         span.set_attribute("models", to_load)
-        logger.info(f"Loaded {to_load}")
+        logger.debug(f"Loaded {to_load}")
 
 
 def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
@@ -627,6 +627,8 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
 
         loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
+        logger.debug(f"Loaded {loaded_model}")
+
 
     span = get_current_span()
     span.set_attribute("models_to_load", list(map(str, models_to_load)))
@@ -265,9 +265,6 @@ class ModelPatcher(ModelManageable):
     def lowvram_patch_counter(self):
         return self._memory_measurements.lowvram_patch_counter
 
-        if not hasattr(self.model, 'current_weight_patches_uuid'):
-            self.model.current_weight_patches_uuid = None
-
     def model_size(self):
         if self.size > 0:
             return self.size
@@ -845,7 +842,10 @@ class ModelPatcher(ModelManageable):
 
     def __str__(self):
         if hasattr(self.model, "operations"):
-            operations_str = self.model.operations.__name__
+            if hasattr(self.model.operations, "__name__"):
+                operations_str = self.model.operations.__name__
+            else:
+                operations_str = str(self.model.operations)
         else:
             operations_str = None
         info_str = f"model_dtype={self.model_dtype()} device={self.model_device} size={naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)} operations={operations_str}"
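As a side note, an equivalent and more compact form of the guard above (a sketch, not the committed code) uses `getattr` with a default:

```python
def _operations_str(model) -> str | None:
    # Same behavior as the branching above: fall back to str() when the
    # operations object has no __name__, and to None when there is none.
    operations = getattr(model, "operations", None)
    return None if operations is None else getattr(operations, "__name__", str(operations))
```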
@@ -63,6 +63,10 @@ NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
 
 InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
 
+# numpy seeds must be between 0 and 2**32 - 1
+Seed = ("INT", {"default": 0, "min": 0, "max": 2**32 - 1})
+SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]
+
 
 class HiddenSpec(TypedDict, total=True):
     prompt: Literal["PROMPT"]
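For illustration, a hypothetical custom node using the shared `Seed` spec introduced above (the node class itself is invented; `CustomNode`, `InputTypes`, and `Seed` are the names this hunk exports):

```python
from comfy.nodes.package_typing import CustomNode, InputTypes, Seed

class ExampleSeededNode(CustomNode):  # hypothetical example, not in the commit
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        # Every seeded node can now share one spec instead of repeating
        # ("INT", {"default": 0, "min": 0, "max": ...}) with varying maxima.
        return {"required": {"seed": Seed}}

    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, seed: int) -> tuple[int]:
        return (seed,)
```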
comfy_extras/chainner_models/__init__.py (new file, 0 lines)
comfy_extras/constants/__init__.py (new file, 0 lines)
@@ -23,6 +23,7 @@ import numpy as np
 import torch
 
 from comfy.nodes.common import MAX_RESOLUTION
+from comfy.nodes.package_typing import Seed
 from comfy.utils import ProgressBar
 import logging as log
 # Sync with theoretical limit from Comfy base
@@ -73,7 +74,7 @@ class INPUT(Enum):
     def MASK():
         return ("MASK",)
     def SEED(default=0):
-        return ("INT", dict(default=default, min=0, max=0xffffffffffffffff))
+        return Seed
     def RESOLUTION(default=512, min=64, max=MAX_RESOLUTION, step=64):
         return ("INT", dict(default=default, min=min, max=max, step=step))
     def INT(default=0, min=0, max=MAX_RESOLUTION, step=1):
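The seed maximum drops here from `0xffffffffffffffff` (2**64 - 1) to the `Seed` spec's 2**32 - 1 because numpy rejects anything larger, which is what the "numpy seeds" comment in the spec refers to:

```python
import numpy as np

np.random.seed(2**32 - 1)   # largest accepted seed
try:
    np.random.seed(2**32)   # one past the limit
except ValueError as e:
    print(e)                # numpy reports the 0 .. 2**32 - 1 bound
```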
@@ -8,6 +8,7 @@ from comfy.cmd import latent_preview
 import torch
 from comfy import utils
 from comfy import node_helpers
+from comfy.nodes.package_typing import Seed
 from comfy.samplers import KSAMPLER
 
 
@@ -458,7 +459,7 @@ class SamplerCustom:
         return {"required":
                     {"model": ("MODEL",),
                      "add_noise": ("BOOLEAN", {"default": True}),
-                     "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                     "noise_seed": Seed,
                      "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                      "positive": ("CONDITIONING", ),
                      "negative": ("CONDITIONING", ),
@@ -610,7 +611,7 @@ class RandomNoise(DisableNoise):
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
-            "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            "noise_seed": Seed,
             }
         }
 
@@ -138,7 +138,7 @@ class Florence2PostProcess(CustomNode):
         return model.processor.post_process_generation(generated_text, task=task, image_size=(images.shape[-2], images.shape[-3])),
 
 
-class Florence2OutputToPolygon(CustomNode):
+class Florence2OutputToMask(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
@@ -166,6 +166,6 @@ NODE_CLASS_MAPPINGS = {}
 for cls in (
         Florence2PostProcess,
         Florence2TaskTokenize,
-        Florence2OutputToPolygon
+        Florence2OutputToMask
 ):
     NODE_CLASS_MAPPINGS[cls.__name__] = cls
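Note that because the registry above keys off `cls.__name__`, the class rename also changes the node id users see:

```python
# Sketch of the registry after this hunk; the id "Florence2OutputToPolygon"
# disappears and "Florence2OutputToMask" takes its place, so workflows saved
# against the old id would need updating.
NODE_CLASS_MAPPINGS = {cls.__name__: cls for cls in (Florence2PostProcess, Florence2TaskTokenize, Florence2OutputToMask)}
```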
@@ -9,7 +9,7 @@ import torch
 from PIL import Image
 
 from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
-from comfy.nodes.package_typing import CustomNode
+from comfy.nodes.package_typing import CustomNode, Seed
 from comfy.utils import pil2tensor, tensor2pil
 from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS
 from comfy_extras.nodes.nodes_mask import MaskToImage
@@ -46,7 +46,7 @@ class IdeogramGenerate(CustomNode):
                 "api_key": ("STRING", {"default": ""}),
                 "negative_prompt": ("STRING", {"multiline": True}),
                 "num_images": ("INT", {"default": 1, "min": 1, "max": 8}),
-                "seed": ("INT", {"default": 0}),
+                "seed": Seed,
             }
         }
@@ -20,7 +20,7 @@ from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWA
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
 from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
-from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
+from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed
 
 _AUTO_CHAT_TEMPLATE = "default"
 
@@ -339,7 +339,7 @@ class TransformersGenerate(CustomNode):
                 "tokens": (TOKENS_TYPE_NAME, {}),
                 "max_new_tokens": ("INT", {"default": 512, "min": 1}),
                 "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
-                "seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
+                "seed": Seed,
             },
             "optional": {
                 "sampler": (GENERATION_KWARGS_TYPE_NAME, {}),
@@ -3,6 +3,7 @@ import torch
 
 import comfy.utils
 from comfy.component_model.tensor_types import Latent
+from comfy.nodes.package_typing import Seed
 from .nodes_post_processing import gaussian_kernel
 
 
@@ -168,7 +169,7 @@ class LatentAddNoiseChannels:
             "required": {
                 "samples": ("LATENT",),
                 "std_dev": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}),
-                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                "seed": Seed,
                 "slice_i": ("INT", {"default": 0, "min": -16, "max": 16}),
                 "slice_j": ("INT", {"default": 16, "min": -16, "max": 16}),
             }
comfy_extras/nodes/nodes_paligemma.py (new file, 250 lines)
@@ -0,0 +1,250 @@
+import functools
+import re
+from importlib.resources import as_file, files
+from typing import TypedDict, NamedTuple
+
+import PIL.Image
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+import torch
+from jaxtyping import Float
+
+from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
+from comfy.nodes.package_typing import CustomNode, InputTypes
+
+_MODEL_PATH = 'vae-oid.npz'
+
+_SEGMENT_DETECT_RE = re.compile(
+    r'(.*?)' +
+    r'<loc(\d{4})>' * 4 + r'\s*' +
+    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+    r'\s*([^;<>]+)? ?(?:; )?',
+)
+
+PALIGEMMA_OUTPUT_NAME = "PALIGEMMA_OUTPUT"
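For reference, a sketch of the kind of string this regex parses (token values invented for illustration): PaliGemma emits four `<loc####>` tokens encoding y1, x1, y2, x2 on a 0-1023 grid, optionally followed by sixteen `<seg###>` tokens indexing the mask codebook, then the object label.

```python
# Invented example; assumes _SEGMENT_DETECT_RE from above is in scope.
example = "<loc0256><loc0128><loc0768><loc0512>" + "<seg001>" * 16 + " cat"
m = _SEGMENT_DETECT_RE.match(example)
assert m is not None
# groups: leading text, 4 loc digit strings, 16 seg digit strings, the label
```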
+
+
+class BoundingBox(NamedTuple):
+    x1: int
+    y1: int
+    x2: int
+    y2: int
+
+
+PaligemmaMask = Float[np.ndarray, "height width"]
+
+
+class ExtractedPaligemmaSegmented(TypedDict):
+    content: str
+    xyxy: BoundingBox
+    mask: PaligemmaMask | None
+    name: str
+
+
+class ExtractedPaligemmaContentOnly(TypedDict):
+    content: str
+
+
+ExtractedPaligemmaObject = ExtractedPaligemmaSegmented | ExtractedPaligemmaContentOnly
+PostProcessResult = list[ExtractedPaligemmaObject]
+
+
+def _get_params(checkpoint):
+    """Converts PyTorch checkpoint to Flax params."""
+
+    def transp(kernel):
+        return np.transpose(kernel, (2, 3, 1, 0))
+
+    def conv(name):
+        return {
+            'bias': checkpoint[name + '.bias'],
+            'kernel': transp(checkpoint[name + '.weight']),
+        }
+
+    def resblock(name):
+        return {
+            'Conv_0': conv(name + '.0'),
+            'Conv_1': conv(name + '.2'),
+            'Conv_2': conv(name + '.4'),
+        }
+
+    return {
+        '_embeddings': checkpoint['_vq_vae._embedding'],
+        'Conv_0': conv('decoder.0'),
+        'ResBlock_0': resblock('decoder.2.net'),
+        'ResBlock_1': resblock('decoder.3.net'),
+        'ConvTranspose_0': conv('decoder.4'),
+        'ConvTranspose_1': conv('decoder.6'),
+        'ConvTranspose_2': conv('decoder.8'),
+        'ConvTranspose_3': conv('decoder.10'),
+        'Conv_1': conv('decoder.12'),
+    }
+
+
+def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+    batch_size, num_tokens = codebook_indices.shape
+    assert num_tokens == 16, codebook_indices.shape
+    unused_num_embeddings, embedding_dim = embeddings.shape
+
+    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+    return encodings
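A shape walk-through of the codebook lookup above, with illustrative sizes (the embedding dimension here is invented; only the 16-token and 4x4 structure come from the code):

```python
import numpy as np

indices = np.zeros((1, 16), dtype=np.int32)   # [B, 16] codebook indices, one per <seg###> token
embeddings = np.zeros((128, 512))             # [num_embeddings, embedding_dim]; 512 is illustrative
encodings = embeddings[indices.reshape(-1)]   # [B*16, 512] gathered vectors
print(encodings.reshape((1, 4, 4, 512)).shape)  # (1, 4, 4, 512): a 4x4 grid the decoder upsamples to 64x64
```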
+
+
+@functools.cache
+def _get_reconstruct_masks():
+    """Reconstructs masks from codebook indices.
+
+    Returns:
+      A function that expects indices shaped `[B, 16]` of dtype int32, each
+      ranging from 0 to 127 (inclusive), and that returns decoded masks sized
+      `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
+    """
+
+    class ResBlock(nn.Module):
+        features: int
+
+        @nn.compact
+        def __call__(self, x):
+            original_x = x
+            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+            x = nn.relu(x)
+            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+            x = nn.relu(x)
+            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
+            return x + original_x
+
+    class Decoder(nn.Module):
+        """Upscales quantized vectors to mask."""
+
+        @nn.compact
+        def __call__(self, x):
+            num_res_blocks = 2
+            dim = 128
+            num_upsample_layers = 4
+
+            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
+            x = nn.relu(x)
+
+            for _ in range(num_res_blocks):
+                x = ResBlock(features=dim)(x)
+
+            for _ in range(num_upsample_layers):
+                x = nn.ConvTranspose(
+                    features=dim,
+                    kernel_size=(4, 4),
+                    strides=(2, 2),
+                    padding=2,
+                    transpose_kernel=True,
+                )(x)
+                x = nn.relu(x)
+                dim //= 2
+
+            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
+
+            return x
+
+    def reconstruct_masks(codebook_indices):
+        quantized = _quantized_values_from_codebook_indices(
+            codebook_indices, params['_embeddings']
+        )
+        return Decoder().apply({'params': params}, quantized)
+
+    with as_file(files("comfy_extras.paligemma") / _MODEL_PATH) as f:
+        params = _get_params(dict(np.load(f)))
+
+    return jax.jit(reconstruct_masks, backend='cpu')
+
+
+def extract_objs(text, width, height, unique_labels=False) -> PostProcessResult:
+    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+    objs: list[ExtractedPaligemmaObject] = []
+    seen = set()
+    while text:
+        m = _SEGMENT_DETECT_RE.match(text)
+        if not m:
+            break
+        gs = list(m.groups())
+        before = gs.pop(0)
+        name = gs.pop()
+        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
+        seg_indices = gs[4:20]
+        if seg_indices[0] is None:
+            mask = None
+        else:
+            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
+            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
+            mask = np.zeros([height, width])
+            if y2 > y1 and x2 > x1:
+                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
+
+        content = m.group()
+        if before:
+            objs.append(dict(content=before))
+            content = content[len(before):]
+        while unique_labels and name in seen:
+            name = (name or '') + "'"
+        seen.add(name)
+        paligemma_output_obj: ExtractedPaligemmaObject = {'content': content, 'xyxy': BoundingBox(x1, y1, x2, y2), 'mask': mask, 'name': name}
+        objs.append(paligemma_output_obj)
+        text = text[len(before) + len(content):]
+
+    if text:
+        objs.append(dict(content=text))
+
+    return [obj for obj in objs if obj["content"] != '<eos>']
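A usage sketch with invented coordinates (no `<seg>` tokens, so the mask decoder and its npz checkpoint are never touched):

```python
# Parse a detection-style answer for a 640x480 image; the four loc tokens
# encode y1, x1, y2, x2 on a 0-1023 grid and are scaled to pixels.
objs = extract_objs("<loc0100><loc0100><loc0900><loc0900> dog", width=640, height=480)
print(objs[0]["name"])   # "dog"
print(objs[0]["xyxy"])   # BoundingBox(x1=62, y1=47, x2=562, y2=422)
print(objs[0]["mask"])   # None, since no <seg###> tokens were present
```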
+
+
+class PaligemmaPostProcess(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "generated_text": ("STRING", {"forceInput": True}),
+            },
+            "optional": {
+                "images": ("IMAGE", {}),
+            }
+        }
+
+    CATEGORY = "language"
+    RETURN_TYPES = (PALIGEMMA_OUTPUT_NAME,)
+    RETURN_NAMES = ("paligemma output",)
+    FUNCTION = "execute"
+
+    def execute(self, generated_text: str = "", task: str = "", images: RGBImageBatch = None) -> tuple[PostProcessResult]:
+        return extract_objs(generated_text, images.shape[-2], images.shape[-3]),
+
+
+class PaligemmaOutputToMask(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "paligemma_output": (PALIGEMMA_OUTPUT_NAME, {"forceInput": True}),
+            },
+        }
+
+    CATEGORY = "language"
+    RETURN_TYPES = ("MASK",)
+    RETURN_NAMES = ("mask",)
+    FUNCTION = "execute"
+
+    def execute(self, paligemma_output: PostProcessResult) -> tuple[MaskBatch]:
+        # skip content-only objects and segmented objects whose mask is None
+        masks = [torch.from_numpy(p["mask"]) for p in paligemma_output if "mask" in p and p["mask"] is not None]
+        if len(masks) == 0:
+            return torch.zeros((0, 0, 0)),
+        return torch.stack(masks, dim=0),
+
+
+NODE_CLASS_MAPPINGS = {}
+for cls in (
+        PaligemmaOutputToMask,
+        PaligemmaPostProcess,
+):
+    NODE_CLASS_MAPPINGS[cls.__name__] = cls
comfy_extras/paligemma/__init__.py (new file, 0 lines)
comfy_extras/paligemma/vae-oid.npz (new binary file, not shown)
@@ -1,4 +1,5 @@
 triton ;platform_system == 'Linux'
-triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12'
-triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'
-triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp313-cp313-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.13'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10'
@@ -70,4 +70,6 @@ pebble>=5.0.7
 openai
 anthropic
 humanize
 lightning
+flax
+jax
tests/issues/__test_29_fix_str_in_model.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+import torch.nn
+
+from comfy.model_patcher import ModelPatcher
+
+
+class HasOperationsNoName(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.operations = object()
+        if hasattr(self.operations, "__name__"):
+            delattr(self.operations, "__name__")
+
+
+def test_str_model_patcher():
+    model_patcher = ModelPatcher(HasOperationsNoName(), torch.device('cpu'), torch.device('cpu'))
+    assert str(model_patcher) is not None