diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 9acf0d7af..53e747752 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -49,7 +49,7 @@ def _create_parser() -> EnhancedConfigArgParser:
 
     cm_group = parser.add_mutually_exclusive_group()
     cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
-    cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
+    cm_group.add_argument("--disable-cuda-malloc", action="store_true", default=True, help="Disable cudaMallocAsync.")
 
     fp_group = parser.add_mutually_exclusive_group()
     fp_group.add_argument("--force-fp32", action="store_true",
diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py
index cd64a1f48..a7a87cb01 100644
--- a/comfy/cli_args_types.py
+++ b/comfy/cli_args_types.py
@@ -142,7 +142,7 @@ class Configuration(dict):
         self.disable_auto_launch: bool = False
         self.cuda_device: Optional[int] = None
         self.cuda_malloc: bool = True
-        self.disable_cuda_malloc: bool = False
+        self.disable_cuda_malloc: bool = True
         self.dont_upcast_attention: bool = False
         self.force_upcast_attention: bool = False
         self.force_fp32: bool = False
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 3387a5ab8..cbccf1e00 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -290,8 +290,10 @@ def format_value(x) -> FormattedValue:
         return None
     elif isinstance(x, (int, float, bool, str)):
         return x
-    else:
+    elif isinstance(x, dict) and not any(isinstance(v, torch.Tensor) for v in x.values()):
         return str(x)
+    else:
+        return str(x.__class__)
 
 
 def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
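For context, a minimal sketch of what the new format_value branches are expected to return (illustrative calls only, not part of the patch):

    import torch
    from comfy.cmd.execution import format_value

    format_value({"scale": 1.0})                        # dict without tensors -> "{'scale': 1.0}"
    format_value({"samples": torch.zeros(1, 4, 8, 8)})  # dict holding a tensor -> "<class 'dict'>"
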
diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py
index f7484c237..7cb072e7e 100644
--- a/comfy/cmd/main_pre.py
+++ b/comfy/cmd/main_pre.py
@@ -63,7 +63,8 @@ except Exception:
 
 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
+os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
 
 def _fix_pytorch_240():
     """Fixes pytorch 2.4.0"""
diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py
index e3980c254..82ab7f4f5 100644
--- a/comfy/language/transformers_model_management.py
+++ b/comfy/language/transformers_model_management.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import copy
 import inspect
 import logging
@@ -29,7 +30,8 @@ from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
 _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2']
 
 # should be added if the expectation is that this model emits special tokens
-_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'}
+_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'}
+
 
 class TransformersManagedModel(ModelManageable, LanguageModel):
     def __init__(
@@ -214,7 +216,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
 
         text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
 
-        with seed_for_block(seed):
+        try:
+            import triton  # pylint: disable=import-error
+            has_triton = True
+        except (ImportError, ModuleNotFoundError):
+            has_triton = False
+
+        with seed_for_block(seed), torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext():
             if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                 inputs.pop("attention_mask")
             output_ids = transformers_model.generate(
@@ -310,7 +318,6 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
 
     def model_dtype(self) -> torch.dtype:
         return self.model.dtype
 
-
     def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
         return self.model.to(device=device_to)
@@ -344,11 +351,12 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
         if isinstance(images, list):
             images = torch.stack(images, dim=0)
         if images is not None:
-            # PIL it for the sake of simplicity
             image_sizes = [(image.shape[-2], image.shape[-3]) for image in images]
         else:
             image_sizes = []
-            images = []
+            # todo: what is the best choice for this?
+            # probably select a size related to the vision tower?
+            images = torch.zeros((0, 0, 0, 3))
 
         try:
             if hasattr(tokenizer, "apply_chat_template"):
@@ -383,8 +391,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
         else:
             if hasattr(self.processor, "to"):
                 self.processor.to(device=self.load_device)
-
-            batch_feature: BatchFeature = self.processor(text=[prompt], images=images.unbind(), return_tensors="pt", padding=True)
+            # convert the tuple from images.unbind() to a list as a workaround for paligemma
+            batch_feature: BatchFeature = self.processor(text=[prompt], images=list(images.unbind()) if images is not None and len(images) > 0 else None, return_tensors="pt", padding=True)
             if hasattr(self.processor, "to"):
                 self.processor.to(device=self.offload_device)
             assert "input_ids" in batch_feature
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index ca3b28d13..dcd1cf7c9 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -452,6 +452,10 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
     'roborovski/superprompt-v1',
     'Qwen/Qwen2-VL-7B-Instruct',
     'microsoft/Florence-2-large-ft',
+    'google/paligemma2-10b-pt-896',
+    'google/paligemma2-28b-pt-896',
+    'google/paligemma-3b-ft-refcoco-seg-896',
+    'microsoft/phi-4',
 }
 
 KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
diff --git a/comfy/model_management.py b/comfy/model_management.py
index e0ccca1a8..1062f6ad8 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -551,7 +551,7 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
         _load_models_gpu(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
         to_load = list(map(str, models))
         span.set_attribute("models", to_load)
-        logger.info(f"Loaded {to_load}")
+        logger.debug(f"Loaded {to_load}")
 
 
 def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
@@ -627,6 +627,8 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
             loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
             current_loaded_models.insert(0, loaded_model)
 
+            logger.debug(f"Loaded {loaded_model}")
+
     span = get_current_span()
     span.set_attribute("models_to_load", list(map(str, models_to_load)))
 
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index c7d31ddd1..91baa688b 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -265,9 +265,6 @@ class ModelPatcher(ModelManageable):
     def lowvram_patch_counter(self):
         return self._memory_measurements.lowvram_patch_counter
 
-        if not hasattr(self.model, 'current_weight_patches_uuid'):
-            self.model.current_weight_patches_uuid = None
-
     def model_size(self):
         if self.size > 0:
             return self.size
@@ -845,7 +842,10 @@ class ModelPatcher(ModelManageable):
 
     def __str__(self):
         if hasattr(self.model, "operations"):
-            operations_str = self.model.operations.__name__
+            if hasattr(self.model.operations, "__name__"):
+                operations_str = self.model.operations.__name__
+            else:
+                operations_str = str(self.model.operations)
         else:
             operations_str = None
         info_str = f"model_dtype={self.model_dtype()} device={self.model_device} size={naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)} operations={operations_str}"
diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py
index 8bf97efce..dd15379d9 100644
--- a/comfy/nodes/package_typing.py
+++ b/comfy/nodes/package_typing.py
@@ -63,6 +63,10 @@ NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
 
 InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
 
+# numpy seeds must be between 0 and 2**32 - 1
+Seed = ("INT", {"default": 0, "min": 0, "max": 2**32 - 1})
+SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]
+
 
 class HiddenSpec(TypedDict, total=True):
     prompt: Literal["PROMPT"]
diff --git a/comfy_extras/chainner_models/__init__.py b/comfy_extras/chainner_models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy_extras/constants/__init__.py b/comfy_extras/constants/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy_extras/controlnet_aux/utils.py b/comfy_extras/controlnet_aux/utils.py
index 3dc9ca1ed..fbe3d40ec 100644
--- a/comfy_extras/controlnet_aux/utils.py
+++ b/comfy_extras/controlnet_aux/utils.py
@@ -23,6 +23,7 @@ import numpy as np
 import torch
 
 from comfy.nodes.common import MAX_RESOLUTION
+from comfy.nodes.package_typing import Seed
 from comfy.utils import ProgressBar
 import logging as log
 # Sync with theoritical limit from Comfy base
@@ -73,7 +74,7 @@ class INPUT(Enum):
     def MASK():
         return ("MASK",)
     def SEED(default=0):
-        return ("INT", dict(default=default, min=0, max=0xffffffffffffffff))
+        return Seed
     def RESOLUTION(default=512, min=64, max=MAX_RESOLUTION, step=64):
         return ("INT", dict(default=default, min=min, max=max, step=step))
     def INT(default=0, min=0, max=MAX_RESOLUTION, step=1):
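The shared Seed tuple is intended to be dropped into INPUT_TYPES wherever a node previously declared its own integer seed input; a minimal sketch using a hypothetical node (not part of the patch):

    from comfy.nodes.package_typing import CustomNode, InputTypes, Seed

    class ExampleSeededNode(CustomNode):
        @classmethod
        def INPUT_TYPES(cls) -> InputTypes:
            # equivalent to ("INT", {"default": 0, "min": 0, "max": 2**32 - 1})
            return {"required": {"seed": Seed}}
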
"max": 0xffffffffffffffff}), + "noise_seed": Seed, } } diff --git a/comfy_extras/nodes/nodes_florence2.py b/comfy_extras/nodes/nodes_florence2.py index 93d99a932..d274d1cdc 100644 --- a/comfy_extras/nodes/nodes_florence2.py +++ b/comfy_extras/nodes/nodes_florence2.py @@ -138,7 +138,7 @@ class Florence2PostProcess(CustomNode): return model.processor.post_process_generation(generated_text, task=task, image_size=(images.shape[-2], images.shape[-3])), -class Florence2OutputToPolygon(CustomNode): +class Florence2OutputToMask(CustomNode): @classmethod def INPUT_TYPES(cls) -> InputTypes: return { @@ -166,6 +166,6 @@ NODE_CLASS_MAPPINGS = {} for cls in ( Florence2PostProcess, Florence2TaskTokenize, - Florence2OutputToPolygon + Florence2OutputToMask ): NODE_CLASS_MAPPINGS[cls.__name__] = cls \ No newline at end of file diff --git a/comfy_extras/nodes/nodes_ideogram.py b/comfy_extras/nodes/nodes_ideogram.py index 7444d09d3..8ccfc11b1 100644 --- a/comfy_extras/nodes/nodes_ideogram.py +++ b/comfy_extras/nodes/nodes_ideogram.py @@ -9,7 +9,7 @@ import torch from PIL import Image from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch -from comfy.nodes.package_typing import CustomNode +from comfy.nodes.package_typing import CustomNode, Seed from comfy.utils import pil2tensor, tensor2pil from comfy_extras.constants.resolutions import IDEOGRAM_RESOLUTIONS from comfy_extras.nodes.nodes_mask import MaskToImage @@ -46,7 +46,7 @@ class IdeogramGenerate(CustomNode): "api_key": ("STRING", {"default": ""}), "negative_prompt": ("STRING", {"multiline": True}), "num_images": ("INT", {"default": 1, "min": 1, "max": 8}), - "seed": ("INT", {"default": 0}), + "seed": Seed, } } diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py index b81d0de29..591a01aaf 100644 --- a/comfy_extras/nodes/nodes_language.py +++ b/comfy_extras/nodes/nodes_language.py @@ -20,7 +20,7 @@ from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWA from comfy.language.transformers_model_management import TransformersManagedModel from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device -from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult +from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed _AUTO_CHAT_TEMPLATE = "default" @@ -339,7 +339,7 @@ class TransformersGenerate(CustomNode): "tokens": (TOKENS_TYPE_NAME, {}), "max_new_tokens": ("INT", {"default": 512, "min": 1}), "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}), - "seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}), + "seed": Seed, }, "optional": { "sampler": (GENERATION_KWARGS_TYPE_NAME, {}), diff --git a/comfy_extras/nodes/nodes_latent.py b/comfy_extras/nodes/nodes_latent.py index 96b2120b8..43e66a1af 100644 --- a/comfy_extras/nodes/nodes_latent.py +++ b/comfy_extras/nodes/nodes_latent.py @@ -3,6 +3,7 @@ import torch import comfy.utils from comfy.component_model.tensor_types import Latent +from comfy.nodes.package_typing import Seed from .nodes_post_processing import gaussian_kernel @@ -168,7 +169,7 @@ class LatentAddNoiseChannels: "required": { "samples": ("LATENT",), "std_dev": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "seed": Seed, "slice_i": ("INT", {"default": 0, "min": -16, "max": 16}), "slice_j": 
("INT", {"default": 16, "min": -16, "max": 16}), } diff --git a/comfy_extras/nodes/nodes_paligemma.py b/comfy_extras/nodes/nodes_paligemma.py new file mode 100644 index 000000000..8d13acff2 --- /dev/null +++ b/comfy_extras/nodes/nodes_paligemma.py @@ -0,0 +1,250 @@ +import functools +import re +from importlib.resources import as_file, files +from typing import TypedDict, NamedTuple + +import PIL.Image +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import torch +from jaxtyping import Float + +from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch +from comfy.nodes.package_typing import CustomNode, InputTypes + +_MODEL_PATH = 'vae-oid.npz' + +_SEGMENT_DETECT_RE = re.compile( + r'(.*?)' + + r'' * 4 + r'\s*' + + '(?:%s)?' % (r'' * 16) + + r'\s*([^;<>]+)? ?(?:; )?', +) + +PALIGEMMA_OUTPUT_NAME = "PALIGEMMA_OUTPUT" + + +class BoundingBox(NamedTuple): + x1: int + y1: int + x2: int + y2: int + + +PaligemmaMask = Float[np.ndarray, "height width"] + + +class ExtractedPaligemmaSegmented(TypedDict): + content: str + xyxy: BoundingBox + mask: PaligemmaMask | None + name: str + + +class ExtractedPaligemmaContentOnly(TypedDict): + content: str + + +ExtractedPaligemmaObject = ExtractedPaligemmaSegmented | ExtractedPaligemmaContentOnly +PostProcessResult = list[ExtractedPaligemmaObject] + + +def _get_params(checkpoint): + """Converts PyTorch checkpoint to Flax params.""" + + def transp(kernel): + return np.transpose(kernel, (2, 3, 1, 0)) + + def conv(name): + return { + 'bias': checkpoint[name + '.bias'], + 'kernel': transp(checkpoint[name + '.weight']), + } + + def resblock(name): + return { + 'Conv_0': conv(name + '.0'), + 'Conv_1': conv(name + '.2'), + 'Conv_2': conv(name + '.4'), + } + + return { + '_embeddings': checkpoint['_vq_vae._embedding'], + 'Conv_0': conv('decoder.0'), + 'ResBlock_0': resblock('decoder.2.net'), + 'ResBlock_1': resblock('decoder.3.net'), + 'ConvTranspose_0': conv('decoder.4'), + 'ConvTranspose_1': conv('decoder.6'), + 'ConvTranspose_2': conv('decoder.8'), + 'ConvTranspose_3': conv('decoder.10'), + 'Conv_1': conv('decoder.12'), + } + + +def _quantized_values_from_codebook_indices(codebook_indices, embeddings): + batch_size, num_tokens = codebook_indices.shape + assert num_tokens == 16, codebook_indices.shape + unused_num_embeddings, embedding_dim = embeddings.shape + + encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0) + encodings = encodings.reshape((batch_size, 4, 4, embedding_dim)) + return encodings + + +@functools.cache +def _get_reconstruct_masks(): + """Reconstructs masks from codebook indices. + Returns: + A function that expects indices shaped `[B, 16]` of dtype int32, each + ranging from 0 to 127 (inclusive), and that returns a decoded masks sized + `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1]. 
+ """ + + class ResBlock(nn.Module): + features: int + + @nn.compact + def __call__(self, x): + original_x = x + x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x) + x = nn.relu(x) + x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x) + x = nn.relu(x) + x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x) + return x + original_x + + class Decoder(nn.Module): + """Upscales quantized vectors to mask.""" + + @nn.compact + def __call__(self, x): + num_res_blocks = 2 + dim = 128 + num_upsample_layers = 4 + + x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x) + x = nn.relu(x) + + for _ in range(num_res_blocks): + x = ResBlock(features=dim)(x) + + for _ in range(num_upsample_layers): + x = nn.ConvTranspose( + features=dim, + kernel_size=(4, 4), + strides=(2, 2), + padding=2, + transpose_kernel=True, + )(x) + x = nn.relu(x) + dim //= 2 + + x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x) + + return x + + def reconstruct_masks(codebook_indices): + quantized = _quantized_values_from_codebook_indices( + codebook_indices, params['_embeddings'] + ) + return Decoder().apply({'params': params}, quantized) + + with as_file(files("comfy_extras.paligemma") / _MODEL_PATH) as f: + params = _get_params(dict(np.load(f))) + + return jax.jit(reconstruct_masks, backend='cpu') + + +def extract_objs(text, width, height, unique_labels=False) -> PostProcessResult: + """Returns objs for a string with "" and "" tokens.""" + objs: list[ExtractedPaligemmaObject] = [] + seen = set() + while text: + m = _SEGMENT_DETECT_RE.match(text) + if not m: + break + gs = list(m.groups()) + before = gs.pop(0) + name = gs.pop() + y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]] + + y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width)) + seg_indices = gs[4:20] + if seg_indices[0] is None: + mask = None + else: + seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32) + m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0] + m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1) + m64 = PIL.Image.fromarray((m64 * 255).astype('uint8')) + mask = np.zeros([height, width]) + if y2 > y1 and x2 > x1: + mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0 + + content = m.group() + if before: + objs.append(dict(content=before)) + content = content[len(before):] + while unique_labels and name in seen: + name = (name or '') + "'" + seen.add(name) + paligemma_output_obj: ExtractedPaligemmaObject = {'content': content, 'xyxy': BoundingBox(x1, y1, x2, y2), 'mask': mask, 'name': name} + objs.append(paligemma_output_obj) + text = text[len(before) + len(content):] + + if text: + objs.append(dict(content=text)) + + return [obj for obj in objs if obj["content"] != ''] + + +class PaligemmaPostProcess(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "generated_text": ("STRING", {"forceInput": True}), + }, + "optional": { + "images": ("IMAGE", {}), + } + } + + CATEGORY = "language" + RETURN_TYPES = (PALIGEMMA_OUTPUT_NAME,) + RETURN_NAMES = ("paligemma output",) + FUNCTION = "execute" + + def execute(self, generated_text: str = "", task: str = "", images: RGBImageBatch = None) -> tuple[PostProcessResult]: + return extract_objs(generated_text, images.shape[-2], images.shape[-3]), + + +class PaligemmaOutputToMask(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "paligemma_output": (PALIGEMMA_OUTPUT_NAME, {"forceInput": True}), + }, + } + + CATEGORY 
= "language" + RETURN_TYPES = ("MASK",) + RETURN_NAMES = ("paligemma output",) + FUNCTION = "execute" + + def execute(self, paligemma_output: PostProcessResult) -> tuple[MaskBatch]: + masks = [torch.from_numpy(p["mask"]) for p in paligemma_output if "mask" in p] + if len(masks) == 0: + return torch.zeroes((0, 0, 0)), + return torch.stack(masks, dim=0), + + +NODE_CLASS_MAPPINGS = {} +for cls in ( + PaligemmaOutputToMask, + PaligemmaPostProcess, +): + NODE_CLASS_MAPPINGS[cls.__name__] = cls diff --git a/comfy_extras/paligemma/__init__.py b/comfy_extras/paligemma/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy_extras/paligemma/vae-oid.npz b/comfy_extras/paligemma/vae-oid.npz new file mode 100644 index 000000000..c4abdbe0f Binary files /dev/null and b/comfy_extras/paligemma/vae-oid.npz differ diff --git a/requirements-triton.txt b/requirements-triton.txt index 6505a453b..0d675bf3a 100644 --- a/requirements-triton.txt +++ b/requirements-triton.txt @@ -1,4 +1,5 @@ triton ;platform_system == 'Linux' -triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12' -triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11' -triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post9/triton-3.1.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10' \ No newline at end of file +triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp313-cp313-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.13' +triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12' +triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11' +triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post9/triton-3.2.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e36d19d22..8568112fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,4 +70,6 @@ pebble>=5.0.7 openai anthropic humanize -lightning \ No newline at end of file +lightning +flax +jax \ No newline at end of file diff --git a/tests/issues/__test_29_fix_str_in_model.py b/tests/issues/__test_29_fix_str_in_model.py new file mode 100644 index 000000000..dc2d43c65 --- /dev/null +++ b/tests/issues/__test_29_fix_str_in_model.py @@ -0,0 +1,16 @@ +import torch.nn + +from comfy.model_patcher import ModelPatcher + + +class HasOperationsNoName(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.operations = object() + if hasattr(self.operations, "__name__"): + delattr(self.operations, "__name__") + + +def test_str_model_patcher(): + model_patcher = ModelPatcher(HasOperationsNoName(), torch.device('cpu'), torch.device('cpu')) + assert str(model_patcher) is not None