From 71698f409913b142f0423c30a63ec3011f5f1825 Mon Sep 17 00:00:00 2001 From: doctorpangloss <2229300+doctorpangloss@users.noreply.github.com> Date: Fri, 26 Dec 2025 15:20:45 -0800 Subject: [PATCH] tests pass --- comfy/api/openapi.yaml | 22 ++-- comfy/client/aio_client.py | 52 +++++++- comfy/client/client_types.py | 1 + comfy/client/embedded_comfy_client.py | 8 +- comfy/cmd/server.py | 5 +- comfy/samplers.py | 5 +- comfy/taesd/taehv.py | 48 ++++--- comfy/text_encoders/jina_clip_2.py | 41 ++++-- comfy/text_encoders/newbie.py | 34 +++-- comfy/text_encoders/sd3_clip.py | 4 +- comfy_extras/nodes/nodes_qwen.py | 3 +- tests/execution/test_execution.py | 42 +++--- tests/execution/test_jobs_from_execution.py | 135 ++++++++++++++++++++ 13 files changed, 314 insertions(+), 86 deletions(-) create mode 100644 tests/execution/test_jobs_from_execution.py diff --git a/comfy/api/openapi.yaml b/comfy/api/openapi.yaml index 260c75a38..d19071167 100644 --- a/comfy/api/openapi.yaml +++ b/comfy/api/openapi.yaml @@ -414,20 +414,6 @@ paths: responses: 200: headers: - Idempotency-Key: - description: | - The API supports idempotency for safely retrying requests without accidentally performing the same operation twice. When creating or updating an object, use an idempotency key. Then, if a connection error occurs, you can safely repeat the request without risk of creating a second object or performing the update twice. - - To perform an idempotent request, provide an additional IdempotencyKey element to the request options. - - Idempotency works by saving the resulting status code and body of the first request made for any given idempotency key, regardless of whether it succeeds or fails. Subsequent requests with the same key return the same result, including 500 errors. - - A client generates an idempotency key, which is a unique key that the server uses to recognize subsequent retries of the same request. How you create unique keys is up to you, but we suggest using V4 UUIDs, or another random string with enough entropy to avoid collisions. Idempotency keys are up to 255 characters long. - - You can remove keys from the system automatically after they’re at least 24 hours old. We generate a new request if a key is reused after the original is pruned. The idempotency layer compares incoming parameters to those of the original request and errors if they’re the same to prevent accidental misuse. - example: XFDSF000213 - schema: - type: string Digest: description: The digest of the request body example: SHA256=e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3 @@ -439,6 +425,10 @@ paths: schema: type: string pattern: '^filename=.+' + Location: + description: The relative URL to revisit this request + schema: + type: string description: | The content of the last SaveImage node. content: @@ -467,6 +457,10 @@ paths: type: array items: type: string + prompt_id: + type: string + description: + The ID of the prompt that was queued and executed outputs: $ref: "#/components/schemas/Outputs" example: diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py index f79f44ac1..99f32bbc4 100644 --- a/comfy/client/aio_client.py +++ b/comfy/client/aio_client.py @@ -1,3 +1,4 @@ +from __future__ import annotations import asyncio import uuid from asyncio import AbstractEventLoop @@ -65,7 +66,7 @@ class AsyncRemoteComfyClient: return headers @tracer.start_as_current_span("Post Prompt") - async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse: + async def _post_prompt(self, prompt: PromptDict | dict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse: """ Common method to POST a prompt to a given endpoint. :param prompt: The prompt to send @@ -101,7 +102,7 @@ class AsyncRemoteComfyClient: else: raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}") - async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse: + async def queue_prompt_api(self, prompt: PromptDict | dict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse: """ Calls the API to queue a prompt. :param prompt: @@ -211,3 +212,50 @@ class AsyncRemoteComfyClient: return response.status, None # Timeout return 408, None + + async def get_jobs(self, status: Optional[str] = None, workflow_id: Optional[str] = None, + limit: Optional[int] = None, offset: Optional[int] = None, + sort_by: Optional[str] = None, sort_order: Optional[str] = None) -> dict: + """ + List all jobs with filtering, sorting, and pagination. + :param status: Filter by status (comma-separated): pending, in_progress, completed, failed + :param workflow_id: Filter by workflow ID + :param limit: Max items to return + :param offset: Items to skip + :param sort_by: Sort field: created_at (default), execution_duration + :param sort_order: Sort direction: asc, desc (default) + :return: Dictionary containing jobs list and pagination info + """ + params = {} + if status is not None: + params["status"] = status + if workflow_id is not None: + params["workflow_id"] = workflow_id + if limit is not None: + params["limit"] = str(limit) + if offset is not None: + params["offset"] = str(offset) + if sort_by is not None: + params["sort_by"] = sort_by + if sort_order is not None: + params["sort_order"] = sort_order + + async with self.session.get(urljoin(self.server_address, "/api/jobs"), params=params) as response: + if response.status == 200: + return await response.json() + else: + raise RuntimeError(f"could not get jobs: {response.status}: {await response.text()}") + + async def get_job(self, job_id: str) -> Optional[dict]: + """ + Get a single job by ID. + :param job_id: The job ID + :return: Job dictionary or None if not found + """ + async with self.session.get(urljoin(self.server_address, f"/api/jobs/{job_id}")) as response: + if response.status == 200: + return await response.json() + elif response.status == 404: + return None + else: + raise RuntimeError(f"could not get job: {response.status}: {await response.text()}") diff --git a/comfy/client/client_types.py b/comfy/client/client_types.py index 7f1d55ec2..5cfda941c 100644 --- a/comfy/client/client_types.py +++ b/comfy/client/client_types.py @@ -24,6 +24,7 @@ class Output(TypedDict, total=False): class V1QueuePromptResponse: urls: List[str] outputs: dict[str, Output] + prompt_id: str class ProgressNotification(NamedTuple): diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index cde995430..1466aee82 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -382,19 +382,21 @@ class Comfy: async def queue_prompt_api(self, prompt: PromptDict | str | dict, - progress_handler: Optional[ExecutorToClientProgress] = None) -> V1QueuePromptResponse: + progress_handler: Optional[ExecutorToClientProgress] = None, + prompt_id: Optional[str] = None) -> V1QueuePromptResponse: """ Queues a prompt for execution, returning the output when it is complete. :param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt :return: a response of URLs for Save-related nodes and the node outputs """ + prompt_id = prompt_id or str(uuid.uuid4()) if isinstance(prompt, str): prompt = json.loads(prompt) if isinstance(prompt, dict): from ..api.components.schema.prompt import Prompt prompt = Prompt.validate(prompt) - outputs = await self.queue_prompt(prompt, progress_handler=progress_handler) - return V1QueuePromptResponse(urls=[], outputs=outputs) + outputs = await self.queue_prompt(prompt, progress_handler=progress_handler, prompt_id=prompt_id) + return V1QueuePromptResponse(urls=[], outputs=outputs, prompt_id=prompt_id) def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress: """ diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index f09e97fae..06dca1cfb 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -1193,6 +1193,7 @@ class PromptServer(ExecutorToClientProgress): filename = main_image["filename"] digest_headers_ = { "Digest": f"SHA-256={content_digest}", + "Location": f"/api/v1/prompts/{task_id}" } urls_ = [] if len(output_images) == 1: @@ -1224,8 +1225,10 @@ class PromptServer(ExecutorToClientProgress): headers=digest_headers_, body=json.dumps({ 'urls': urls_, - 'outputs': result.outputs + 'outputs': result.outputs, + 'prompt_id': task_id, })) + # todo: provide more ways to accept these, or support multiple file returns easily elif accept == "image/png" or accept == "image/jpeg": return web.FileResponse(main_image["abs_path"], headers=digest_headers_) diff --git a/comfy/samplers.py b/comfy/samplers.py index e9f6cd746..356c653d6 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -23,6 +23,7 @@ from .k_diffusion import sampling as k_diffusion_sampling from .model_base import BaseModel from .model_management_types import ModelOptions from .model_patcher import ModelPatcher +from .sampler_helpers import prepare_mask from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES, KSAMPLER_NAMES from .context_windows import ContextHandlerABC from .utils import common_upscale, pack_latents, unpack_latents @@ -1068,10 +1069,10 @@ class CFGGuider: denoise_masks.append(torch.ones(latent_shapes[i])) for i in range(len(denoise_masks)): - denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device) + denoise_masks[i] = prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device) if len(denoise_masks) > 1: - denoise_mask, _ = comfy.utils.pack_latents(denoise_masks) + denoise_mask, _ = pack_latents(denoise_masks) else: denoise_mask = denoise_masks[0] diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index a466fbee0..aaa5326a7 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -6,59 +6,68 @@ import torch.nn.functional as F from tqdm.auto import tqdm from collections import namedtuple, deque -import comfy.ops -operations=comfy.ops.disable_weight_init +from ..ops import disable_weight_init + +operations = disable_weight_init DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + def conv(n_in, n_out, **kwargs): return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + class Clamp(nn.Module): def forward(self, x): return torch.tanh(x / 3) * 3 + class MemBlock(nn.Module): def __init__(self, n_in, n_out, act_func): super().__init__() self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out)) self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.act = act_func + def forward(self, x, past): return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + class TPool(nn.Module): def __init__(self, n_f, stride): super().__init__() self.stride = stride - self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False) + self.conv = operations.Conv2d(n_f * stride, n_f, 1, bias=False) + def forward(self, x): _NT, C, H, W = x.shape return self.conv(x.reshape(-1, self.stride * C, H, W)) + class TGrow(nn.Module): def __init__(self, n_f, stride): super().__init__() self.stride = stride - self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False) + self.conv = operations.Conv2d(n_f, n_f * stride, 1, bias=False) + def forward(self, x): _NT, C, H, W = x.shape x = self.conv(x) return x.reshape(-1, C, H, W) -def apply_model_with_memblocks(model, x, parallel, show_progress_bar): +def apply_model_with_memblocks(model, x, parallel, show_progress_bar): B, T, C, H, W = x.shape if parallel: - x = x.reshape(B*T, C, H, W) + x = x.reshape(B * T, C, H, W) # parallel over input timesteps, iterate over blocks for b in tqdm(model, disable=not show_progress_bar): if isinstance(b, MemBlock): BT, C, H, W = x.shape T = BT // B _x = x.reshape(B, T, C, H, W) - mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape) + mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape) x = b(x, mem) else: x = b(x) @@ -87,25 +96,25 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar): xt_new = b(xt, mem[i]) mem[i] = xt.detach().clone() del xt - work_queue.appendleft(TWorkItem(xt_new, i+1)) + work_queue.appendleft(TWorkItem(xt_new, i + 1)) elif isinstance(b, TPool): if mem[i] is None: mem[i] = [] mem[i].append(xt.detach().clone()) if len(mem[i]) == b.stride: B, C, H, W = xt.shape - xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W)) + xt = b(torch.cat(mem[i], 1).view(B * b.stride, C, H, W)) mem[i] = [] - work_queue.appendleft(TWorkItem(xt, i+1)) + work_queue.appendleft(TWorkItem(xt, i + 1)) elif isinstance(b, TGrow): xt = b(xt) NT, C, H, W = xt.shape - for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)): - work_queue.appendleft(TWorkItem(xt_next, i+1)) + for xt_next in reversed(xt.view(B, b.stride * C, H, W).chunk(b.stride, 1)): + work_queue.appendleft(TWorkItem(xt_next, i + 1)) del xt else: xt = b(xt) - work_queue.appendleft(TWorkItem(xt, i+1)) + work_queue.appendleft(TWorkItem(xt, i + 1)) progress_bar.close() x = torch.stack(out, 1) return x @@ -122,29 +131,30 @@ class TAEHV(nn.Module): self.show_progress_bar = show_progress_bar self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x) self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) - if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 + if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 self.patch_size = 2 - if self.latent_channels == 32: # HunyuanVideo1.5 + if self.latent_channels == 32: # HunyuanVideo1.5 act_func = nn.LeakyReLU(0.2, inplace=True) - else: # HunyuanVideo, Wan 2.1 + else: # HunyuanVideo, Wan 2.1 act_func = nn.ReLU(inplace=True) self.encoder = nn.Sequential( - conv(self.image_channels*self.patch_size**2, 64), act_func, + conv(self.image_channels * self.patch_size ** 2, 64), act_func, TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), conv(64, self.latent_channels), ) n_f = [256, 128, 64, 64] - self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1 self.decoder = nn.Sequential( Clamp(), conv(self.latent_channels, n_f[0]), act_func, MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), - act_func, conv(n_f[3], self.image_channels*self.patch_size**2), + act_func, conv(n_f[3], self.image_channels * self.patch_size ** 2), ) + @property def show_progress_bar(self): return self._show_progress_bar diff --git a/comfy/text_encoders/jina_clip_2.py b/comfy/text_encoders/jina_clip_2.py index 0cffb6d16..3ab982f98 100644 --- a/comfy/text_encoders/jina_clip_2.py +++ b/comfy/text_encoders/jina_clip_2.py @@ -8,13 +8,15 @@ import torch from torch import nn as nn from torch.nn import functional as F -import comfy.model_management -import comfy.ops -from comfy import sd1_clip from .spiece_tokenizer import SPieceTokenizer +from .. import sd1_clip +from ..ldm.modules.attention import optimized_attention_for_device + class JinaClip2Tokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} tokenizer = tokenizer_data.get("spiece_model", None) # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192 super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data) @@ -22,10 +24,14 @@ class JinaClip2Tokenizer(sd1_clip.SDTokenizer): def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} + class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2") + # https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json @dataclass class XLMRobertaConfig: @@ -44,6 +50,7 @@ class XLMRobertaConfig: eos_token_id: int = 2 pad_token_id: int = 1 + class XLMRobertaEmbeddings(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() @@ -61,6 +68,7 @@ class XLMRobertaEmbeddings(nn.Module): embeddings = embeddings + token_type_embeddings return embeddings + class RotaryEmbedding(nn.Module): def __init__(self, dim, base, device=None): super().__init__() @@ -95,6 +103,7 @@ class RotaryEmbedding(nn.Module): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + class MHA(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() @@ -122,6 +131,7 @@ class MHA(nn.Module): out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True) return self.out_proj(out) + class MLP(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() @@ -135,6 +145,7 @@ class MLP(nn.Module): x = self.fc2(x) return x + class Block(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() @@ -152,17 +163,19 @@ class Block(nn.Module): hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states) return hidden_states + class XLMRobertaEncoder(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)]) def forward(self, hidden_states, attention_mask=None): - optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True) + optimized_attention = optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True) for layer in self.layers: hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention) return hidden_states + class XLMRobertaModel_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() @@ -171,7 +184,9 @@ class XLMRobertaModel_(nn.Module): self.emb_drop = nn.Dropout(config.hidden_dropout_prob) self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops) - def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=None): + if embeds_info is None: + embeds_info = [] x = self.embeddings(input_ids=input_ids, embeddings=embeds) x = self.emb_ln(x) x = self.emb_drop(x) @@ -194,6 +209,7 @@ class XLMRobertaModel_(nn.Module): # Intermediate output is not yet implemented, use None for placeholder return sequence_output, None, pooled_output + class XLMRobertaModel(nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -210,10 +226,17 @@ class XLMRobertaModel(nn.Module): def forward(self, *args, **kwargs): return self.model(*args, **kwargs) + class JinaClip2TextModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}): + def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None): + if model_options is None: + model_options = {} super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) + class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}): + def __init__(self, device="cpu", dtype=None, model_options=None): + if model_options is None: + model_options = {} super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options) + diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py index db2324576..a898a1351 100644 --- a/comfy/text_encoders/newbie.py +++ b/comfy/text_encoders/newbie.py @@ -1,15 +1,18 @@ import torch -import comfy.model_management -import comfy.text_encoders.jina_clip_2 -import comfy.text_encoders.lumina2 +from ..model_management import pick_weight_dtype +from .jina_clip_2 import JinaClip2TextModel, JinaClip2Tokenizer +from .lumina2 import Gemma3_4BTokenizer, Gemma3_4BModel + class NewBieTokenizer: - def __init__(self, embedding_directory=None, tokenizer_data={}): - self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]}) - self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]}) + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} + self.gemma = Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]}) + self.jina = JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]}) - def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): out = {} out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs) out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs) @@ -21,12 +24,15 @@ class NewBieTokenizer: def state_dict(self): return {} + class NewBieTEModel(torch.nn.Module): - def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}): + def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options=None): super().__init__() - dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device) - self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options) - self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options) + if model_options is None: + model_options = {} + dtype_gemma = pick_weight_dtype(dtype_gemma, dtype, device) + self.gemma = Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options) + self.jina = JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options) self.dtypes = {dtype, dtype_gemma} def set_clip_options(self, options): @@ -52,11 +58,15 @@ class NewBieTEModel(torch.nn.Module): else: return self.jina.load_sd(sd) + def te(dtype_llama=None, llama_quantization_metadata=None): class NewBieTEModel_(NewBieTEModel): - def __init__(self, device="cpu", dtype=None, model_options={}): + def __init__(self, device="cpu", dtype=None, model_options=None): + if model_options is None: + model_options = {} if llama_quantization_metadata is not None: model_options = model_options.copy() model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return NewBieTEModel_ diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index ac99e413d..1f7cf31fc 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -1,6 +1,6 @@ import copy import logging -import comfy.utils +from ..utils import detect_layer_quantization import torch from ..transformers_compat import T5TokenizerFast @@ -33,7 +33,7 @@ def t5_xxl_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_t5"] = state_dict[t5_key].dtype - quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + quant = detect_layer_quantization(state_dict, prefix) if quant is not None: out["t5_quantization_metadata"] = quant diff --git a/comfy_extras/nodes/nodes_qwen.py b/comfy_extras/nodes/nodes_qwen.py index 9870873c0..427dbfc35 100644 --- a/comfy_extras/nodes/nodes_qwen.py +++ b/comfy_extras/nodes/nodes_qwen.py @@ -5,7 +5,8 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io import comfy.model_management import torch -import nodes +from comfy.nodes import base_nodes as nodes + class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index 447fe9bcc..b0ba46ef8 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -845,13 +845,13 @@ class TestExecution: assert numpy.array(result_image).mean() == 0, "Image should be black" # Jobs API tests - def test_jobs_api_job_structure( + async def test_jobs_api_job_structure( self, client: ComfyClient, builder: GraphBuilder ): """Test that job objects have required fields""" - self._create_history_item(client, builder) + await self._create_history_item(client, builder) - jobs_response = client.get_jobs(status="completed", limit=1) + jobs_response = await client.get_jobs(status="completed", limit=1) assert len(jobs_response["jobs"]) > 0, "Should have at least one job" job = jobs_response["jobs"][0] @@ -861,13 +861,13 @@ class TestExecution: assert "outputs_count" in job, "Job should have outputs_count" assert "preview_output" in job, "Job should have preview_output" - def test_jobs_api_preview_output_structure( + async def test_jobs_api_preview_output_structure( self, client: ComfyClient, builder: GraphBuilder ): """Test that preview_output has correct structure""" - self._create_history_item(client, builder) + await self._create_history_item(client, builder) - jobs_response = client.get_jobs(status="completed", limit=1) + jobs_response = await client.get_jobs(status="completed", limit=1) job = jobs_response["jobs"][0] if job["preview_output"] is not None: @@ -876,15 +876,15 @@ class TestExecution: assert "nodeId" in preview, "Preview should have nodeId" assert "mediaType" in preview, "Preview should have mediaType" - def test_jobs_api_pagination( + async def test_jobs_api_pagination( self, client: ComfyClient, builder: GraphBuilder ): """Test jobs API pagination""" for _ in range(5): - self._create_history_item(client, builder) + await self._create_history_item(client, builder) - first_page = client.get_jobs(limit=2, offset=0) - second_page = client.get_jobs(limit=2, offset=2) + first_page = await client.get_jobs(limit=2, offset=0) + second_page = await client.get_jobs(limit=2, offset=2) assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs" assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs" @@ -893,15 +893,15 @@ class TestExecution: second_ids = {j["id"] for j in second_page["jobs"]} assert first_ids.isdisjoint(second_ids), "Pages should have different jobs" - def test_jobs_api_sorting( + async def test_jobs_api_sorting( self, client: ComfyClient, builder: GraphBuilder ): """Test jobs API sorting""" for _ in range(3): - self._create_history_item(client, builder) + await self._create_history_item(client, builder) - desc_jobs = client.get_jobs(sort_order="desc") - asc_jobs = client.get_jobs(sort_order="asc") + desc_jobs = await client.get_jobs(sort_order="desc") + asc_jobs = await client.get_jobs(sort_order="asc") if len(desc_jobs["jobs"]) >= 2: desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]] @@ -911,31 +911,31 @@ class TestExecution: if len(asc_times) >= 2: assert asc_times == sorted(asc_times), "Asc should be oldest first" - def test_jobs_api_status_filter( + async def test_jobs_api_status_filter( self, client: ComfyClient, builder: GraphBuilder ): """Test jobs API status filtering""" - self._create_history_item(client, builder) + await self._create_history_item(client, builder) - completed_jobs = client.get_jobs(status="completed") + completed_jobs = await client.get_jobs(status="completed") assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history" for job in completed_jobs["jobs"]: assert job["status"] == "completed", "Should only return completed jobs" # Pending jobs are transient - just verify filter doesn't error - pending_jobs = client.get_jobs(status="pending") + pending_jobs = await client.get_jobs(status="pending") for job in pending_jobs["jobs"]: assert job["status"] == "pending", "Should only return pending jobs" - def test_get_job_by_id( + async def test_get_job_by_id( self, client: ComfyClient, builder: GraphBuilder ): """Test getting a single job by ID""" - result = self._create_history_item(client, builder) + result = await self._create_history_item(client, builder) prompt_id = result.get_prompt_id() - job = client.get_job(prompt_id) + job = await client.get_job(prompt_id) assert job is not None, "Should find the job" assert job["id"] == prompt_id, "Job ID should match" assert "outputs" in job, "Single job should include outputs" diff --git a/tests/execution/test_jobs_from_execution.py b/tests/execution/test_jobs_from_execution.py new file mode 100644 index 000000000..045ab56f2 --- /dev/null +++ b/tests/execution/test_jobs_from_execution.py @@ -0,0 +1,135 @@ +# Jobs API tests +import dataclasses + +import pytest + +from comfy.client.aio_client import AsyncRemoteComfyClient +from comfy_execution.graph_utils import GraphBuilder + +@dataclasses.dataclass +class Result: + res:dict + + def get_prompt_id(self): + return self.res["prompt_id"] + + +class TestJobs: + @pytest.fixture + def builder(self, request): + yield GraphBuilder(prefix=request.node.name) + + @pytest.fixture + async def client(self, comfy_background_server_from_config): + async with AsyncRemoteComfyClient(f"http://localhost:{comfy_background_server_from_config[0].port}") as obj: + yield obj + + async def _create_history_item(self, client, builder): + g = GraphBuilder(prefix="offset_test") + input_node = g.node( + "StubImage", content="BLACK", height=32, width=32, batch_size=1 + ) + g.node("SaveImage", images=input_node.out(0)) + await client.queue_prompt_api(g.finalize()) + + async def test_jobs_api_job_structure( + self, client: AsyncRemoteComfyClient, builder: GraphBuilder + ): + """Test that job objects have required fields""" + await self._create_history_item(client, builder) + + jobs_response = await client.get_jobs(status="completed", limit=1) + assert len(jobs_response["jobs"]) > 0, "Should have at least one job" + + job = jobs_response["jobs"][0] + assert "id" in job, "Job should have id" + assert "status" in job, "Job should have status" + assert "create_time" in job, "Job should have create_time" + assert "outputs_count" in job, "Job should have outputs_count" + assert "preview_output" in job, "Job should have preview_output" + + async def test_jobs_api_preview_output_structure( + self, client: AsyncRemoteComfyClient, builder: GraphBuilder + ): + """Test that preview_output has correct structure""" + await self._create_history_item(client, builder) + + jobs_response = await client.get_jobs(status="completed", limit=1) + job = jobs_response["jobs"][0] + + if job["preview_output"] is not None: + preview = job["preview_output"] + assert "filename" in preview, "Preview should have filename" + assert "nodeId" in preview, "Preview should have nodeId" + assert "mediaType" in preview, "Preview should have mediaType" + + async def test_jobs_api_pagination( + self, client: AsyncRemoteComfyClient, builder: GraphBuilder + ): + """Test jobs API pagination""" + for _ in range(5): + await self._create_history_item(client, builder) + + first_page = await client.get_jobs(limit=2, offset=0) + second_page = await client.get_jobs(limit=2, offset=2) + + assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs" + assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs" + + first_ids = {j["id"] for j in first_page["jobs"]} + second_ids = {j["id"] for j in second_page["jobs"]} + assert first_ids.isdisjoint(second_ids), "Pages should have different jobs" + + async def test_jobs_api_sorting( + self, client: AsyncRemoteComfyClient, builder: GraphBuilder + ): + """Test jobs API sorting""" + for _ in range(3): + await self._create_history_item(client, builder) + + desc_jobs = await client.get_jobs(sort_order="desc") + asc_jobs = await client.get_jobs(sort_order="asc") + + if len(desc_jobs["jobs"]) >= 2: + desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]] + asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]] + if len(desc_times) >= 2: + assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first" + if len(asc_times) >= 2: + assert asc_times == sorted(asc_times), "Asc should be oldest first" + + async def test_jobs_api_status_filter( + self, client: AsyncRemoteComfyClient, builder: GraphBuilder + ): + """Test jobs API status filtering""" + await self._create_history_item(client, builder) + + completed_jobs = await client.get_jobs(status="completed") + assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history" + + for job in completed_jobs["jobs"]: + assert job["status"] == "completed", "Should only return completed jobs" + + # Pending jobs are transient - just verify filter doesn't error + pending_jobs = await client.get_jobs(status="pending") + for job in pending_jobs["jobs"]: + assert job["status"] == "pending", "Should only return pending jobs" + + async def test_get_job_by_id( + self, client: AsyncRemoteComfyClient, builder: GraphBuilder + ): + """Test getting a single job by ID""" + result = await self._create_history_item(client, builder) + prompt_id = result.get_prompt_id() + + job = await client.get_job(prompt_id) + assert job is not None, "Should find the job" + assert job["id"] == prompt_id, "Job ID should match" + assert "outputs" in job, "Single job should include outputs" + + async def test_get_job_not_found( + self, client: AsyncRemoteComfyClient, builder: GraphBuilder + ): + """Test getting a non-existent job returns 404""" + job = await client.get_job("nonexistent-job-id") + assert job is None, "Non-existent job should return None"