tests pass

This commit is contained in:
doctorpangloss 2025-12-26 15:20:45 -08:00
parent 1b29ff09da
commit 71698f4099
13 changed files with 314 additions and 86 deletions

View File

@ -414,20 +414,6 @@ paths:
responses:
200:
headers:
Idempotency-Key:
description: |
The API supports idempotency for safely retrying requests without accidentally performing the same operation twice. When creating or updating an object, use an idempotency key. Then, if a connection error occurs, you can safely repeat the request without risk of creating a second object or performing the update twice.
To perform an idempotent request, provide an additional IdempotencyKey element to the request options.
Idempotency works by saving the resulting status code and body of the first request made for any given idempotency key, regardless of whether it succeeds or fails. Subsequent requests with the same key return the same result, including 500 errors.
A client generates an idempotency key, which is a unique key that the server uses to recognize subsequent retries of the same request. How you create unique keys is up to you, but we suggest using V4 UUIDs, or another random string with enough entropy to avoid collisions. Idempotency keys are up to 255 characters long.
You can remove keys from the system automatically after theyre at least 24 hours old. We generate a new request if a key is reused after the original is pruned. The idempotency layer compares incoming parameters to those of the original request and errors if theyre the same to prevent accidental misuse.
example: XFDSF000213
schema:
type: string
Digest:
description: The digest of the request body
example: SHA256=e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
@ -439,6 +425,10 @@ paths:
schema:
type: string
pattern: '^filename=.+'
Location:
description: The relative URL to revisit this request
schema:
type: string
description: |
The content of the last SaveImage node.
content:
@ -467,6 +457,10 @@ paths:
type: array
items:
type: string
prompt_id:
type: string
description:
The ID of the prompt that was queued and executed
outputs:
$ref: "#/components/schemas/Outputs"
example:

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import asyncio
import uuid
from asyncio import AbstractEventLoop
@ -65,7 +66,7 @@ class AsyncRemoteComfyClient:
return headers
@tracer.start_as_current_span("Post Prompt")
async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
async def _post_prompt(self, prompt: PromptDict | dict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
"""
Common method to POST a prompt to a given endpoint.
:param prompt: The prompt to send
@ -101,7 +102,7 @@ class AsyncRemoteComfyClient:
else:
raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
async def queue_prompt_api(self, prompt: PromptDict | dict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
"""
Calls the API to queue a prompt.
:param prompt:
@ -211,3 +212,50 @@ class AsyncRemoteComfyClient:
return response.status, None
# Timeout
return 408, None
async def get_jobs(self, status: Optional[str] = None, workflow_id: Optional[str] = None,
limit: Optional[int] = None, offset: Optional[int] = None,
sort_by: Optional[str] = None, sort_order: Optional[str] = None) -> dict:
"""
List all jobs with filtering, sorting, and pagination.
:param status: Filter by status (comma-separated): pending, in_progress, completed, failed
:param workflow_id: Filter by workflow ID
:param limit: Max items to return
:param offset: Items to skip
:param sort_by: Sort field: created_at (default), execution_duration
:param sort_order: Sort direction: asc, desc (default)
:return: Dictionary containing jobs list and pagination info
"""
params = {}
if status is not None:
params["status"] = status
if workflow_id is not None:
params["workflow_id"] = workflow_id
if limit is not None:
params["limit"] = str(limit)
if offset is not None:
params["offset"] = str(offset)
if sort_by is not None:
params["sort_by"] = sort_by
if sort_order is not None:
params["sort_order"] = sort_order
async with self.session.get(urljoin(self.server_address, "/api/jobs"), params=params) as response:
if response.status == 200:
return await response.json()
else:
raise RuntimeError(f"could not get jobs: {response.status}: {await response.text()}")
async def get_job(self, job_id: str) -> Optional[dict]:
"""
Get a single job by ID.
:param job_id: The job ID
:return: Job dictionary or None if not found
"""
async with self.session.get(urljoin(self.server_address, f"/api/jobs/{job_id}")) as response:
if response.status == 200:
return await response.json()
elif response.status == 404:
return None
else:
raise RuntimeError(f"could not get job: {response.status}: {await response.text()}")

View File

@ -24,6 +24,7 @@ class Output(TypedDict, total=False):
class V1QueuePromptResponse:
urls: List[str]
outputs: dict[str, Output]
prompt_id: str
class ProgressNotification(NamedTuple):

View File

@ -382,19 +382,21 @@ class Comfy:
async def queue_prompt_api(self,
prompt: PromptDict | str | dict,
progress_handler: Optional[ExecutorToClientProgress] = None) -> V1QueuePromptResponse:
progress_handler: Optional[ExecutorToClientProgress] = None,
prompt_id: Optional[str] = None) -> V1QueuePromptResponse:
"""
Queues a prompt for execution, returning the output when it is complete.
:param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
:return: a response of URLs for Save-related nodes and the node outputs
"""
prompt_id = prompt_id or str(uuid.uuid4())
if isinstance(prompt, str):
prompt = json.loads(prompt)
if isinstance(prompt, dict):
from ..api.components.schema.prompt import Prompt
prompt = Prompt.validate(prompt)
outputs = await self.queue_prompt(prompt, progress_handler=progress_handler)
return V1QueuePromptResponse(urls=[], outputs=outputs)
outputs = await self.queue_prompt(prompt, progress_handler=progress_handler, prompt_id=prompt_id)
return V1QueuePromptResponse(urls=[], outputs=outputs, prompt_id=prompt_id)
def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress:
"""

View File

@ -1193,6 +1193,7 @@ class PromptServer(ExecutorToClientProgress):
filename = main_image["filename"]
digest_headers_ = {
"Digest": f"SHA-256={content_digest}",
"Location": f"/api/v1/prompts/{task_id}"
}
urls_ = []
if len(output_images) == 1:
@ -1224,8 +1225,10 @@ class PromptServer(ExecutorToClientProgress):
headers=digest_headers_,
body=json.dumps({
'urls': urls_,
'outputs': result.outputs
'outputs': result.outputs,
'prompt_id': task_id,
}))
# todo: provide more ways to accept these, or support multiple file returns easily
elif accept == "image/png" or accept == "image/jpeg":
return web.FileResponse(main_image["abs_path"],
headers=digest_headers_)

View File

@ -23,6 +23,7 @@ from .k_diffusion import sampling as k_diffusion_sampling
from .model_base import BaseModel
from .model_management_types import ModelOptions
from .model_patcher import ModelPatcher
from .sampler_helpers import prepare_mask
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES, KSAMPLER_NAMES
from .context_windows import ContextHandlerABC
from .utils import common_upscale, pack_latents, unpack_latents
@ -1068,10 +1069,10 @@ class CFGGuider:
denoise_masks.append(torch.ones(latent_shapes[i]))
for i in range(len(denoise_masks)):
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
denoise_masks[i] = prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
if len(denoise_masks) > 1:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
denoise_mask, _ = pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]

View File

@ -6,49 +6,58 @@ import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple, deque
import comfy.ops
operations=comfy.ops.disable_weight_init
from ..ops import disable_weight_init
operations = disable_weight_init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out, act_func):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = act_func
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f * stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f, n_f * stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
B, T, C, H, W = x.shape
if parallel:
x = x.reshape(B * T, C, H, W)
@ -145,6 +154,7 @@ class TAEHV(nn.Module):
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
act_func, conv(n_f[3], self.image_channels * self.patch_size ** 2),
)
@property
def show_progress_bar(self):
return self._show_progress_bar

View File

@ -8,13 +8,15 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
from .. import sd1_clip
from ..ldm.modules.attention import optimized_attention_for_device
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer = tokenizer_data.get("spiece_model", None)
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
@ -22,10 +24,14 @@ class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
@ -44,6 +50,7 @@ class XLMRobertaConfig:
eos_token_id: int = 2
pad_token_id: int = 1
class XLMRobertaEmbeddings(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -61,6 +68,7 @@ class XLMRobertaEmbeddings(nn.Module):
embeddings = embeddings + token_type_embeddings
return embeddings
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base, device=None):
super().__init__()
@ -95,6 +103,7 @@ class RotaryEmbedding(nn.Module):
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class MHA(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -122,6 +131,7 @@ class MHA(nn.Module):
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
return self.out_proj(out)
class MLP(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -135,6 +145,7 @@ class MLP(nn.Module):
x = self.fc2(x)
return x
class Block(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -152,17 +163,19 @@ class Block(nn.Module):
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
return hidden_states
class XLMRobertaEncoder(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None):
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
optimized_attention = optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
for layer in self.layers:
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
return hidden_states
class XLMRobertaModel_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -171,7 +184,9 @@ class XLMRobertaModel_(nn.Module):
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=None):
if embeds_info is None:
embeds_info = []
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
x = self.emb_ln(x)
x = self.emb_drop(x)
@ -194,6 +209,7 @@ class XLMRobertaModel_(nn.Module):
# Intermediate output is not yet implemented, use None for placeholder
return sequence_output, None, pooled_output
class XLMRobertaModel(nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@ -210,10 +226,17 @@ class XLMRobertaModel(nn.Module):
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
class JinaClip2TextModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)

View File

@ -1,13 +1,16 @@
import torch
import comfy.model_management
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.lumina2
from ..model_management import pick_weight_dtype
from .jina_clip_2 import JinaClip2TextModel, JinaClip2Tokenizer
from .lumina2 import Gemma3_4BTokenizer, Gemma3_4BModel
class NewBieTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
self.gemma = Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
self.jina = JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
out = {}
@ -21,12 +24,15 @@ class NewBieTokenizer:
def state_dict(self):
return {}
class NewBieTEModel(torch.nn.Module):
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options=None):
super().__init__()
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
if model_options is None:
model_options = {}
dtype_gemma = pick_weight_dtype(dtype_gemma, dtype, device)
self.gemma = Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
self.jina = JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = {dtype, dtype_gemma}
def set_clip_options(self, options):
@ -52,11 +58,15 @@ class NewBieTEModel(torch.nn.Module):
else:
return self.jina.load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class NewBieTEModel_(NewBieTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return NewBieTEModel_

View File

@ -1,6 +1,6 @@
import copy
import logging
import comfy.utils
from ..utils import detect_layer_quantization
import torch
from ..transformers_compat import T5TokenizerFast
@ -33,7 +33,7 @@ def t5_xxl_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_t5"] = state_dict[t5_key].dtype
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
quant = detect_layer_quantization(state_dict, prefix)
if quant is not None:
out["t5_quantization_metadata"] = quant

View File

@ -5,7 +5,8 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
import torch
import nodes
from comfy.nodes import base_nodes as nodes
class TextEncodeQwenImageEdit(io.ComfyNode):
@classmethod

View File

@ -845,13 +845,13 @@ class TestExecution:
assert numpy.array(result_image).mean() == 0, "Image should be black"
# Jobs API tests
def test_jobs_api_job_structure(
async def test_jobs_api_job_structure(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test that job objects have required fields"""
self._create_history_item(client, builder)
await self._create_history_item(client, builder)
jobs_response = client.get_jobs(status="completed", limit=1)
jobs_response = await client.get_jobs(status="completed", limit=1)
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
job = jobs_response["jobs"][0]
@ -861,13 +861,13 @@ class TestExecution:
assert "outputs_count" in job, "Job should have outputs_count"
assert "preview_output" in job, "Job should have preview_output"
def test_jobs_api_preview_output_structure(
async def test_jobs_api_preview_output_structure(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test that preview_output has correct structure"""
self._create_history_item(client, builder)
await self._create_history_item(client, builder)
jobs_response = client.get_jobs(status="completed", limit=1)
jobs_response = await client.get_jobs(status="completed", limit=1)
job = jobs_response["jobs"][0]
if job["preview_output"] is not None:
@ -876,15 +876,15 @@ class TestExecution:
assert "nodeId" in preview, "Preview should have nodeId"
assert "mediaType" in preview, "Preview should have mediaType"
def test_jobs_api_pagination(
async def test_jobs_api_pagination(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API pagination"""
for _ in range(5):
self._create_history_item(client, builder)
await self._create_history_item(client, builder)
first_page = client.get_jobs(limit=2, offset=0)
second_page = client.get_jobs(limit=2, offset=2)
first_page = await client.get_jobs(limit=2, offset=0)
second_page = await client.get_jobs(limit=2, offset=2)
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
@ -893,15 +893,15 @@ class TestExecution:
second_ids = {j["id"] for j in second_page["jobs"]}
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
def test_jobs_api_sorting(
async def test_jobs_api_sorting(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API sorting"""
for _ in range(3):
self._create_history_item(client, builder)
await self._create_history_item(client, builder)
desc_jobs = client.get_jobs(sort_order="desc")
asc_jobs = client.get_jobs(sort_order="asc")
desc_jobs = await client.get_jobs(sort_order="desc")
asc_jobs = await client.get_jobs(sort_order="asc")
if len(desc_jobs["jobs"]) >= 2:
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
@ -911,31 +911,31 @@ class TestExecution:
if len(asc_times) >= 2:
assert asc_times == sorted(asc_times), "Asc should be oldest first"
def test_jobs_api_status_filter(
async def test_jobs_api_status_filter(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API status filtering"""
self._create_history_item(client, builder)
await self._create_history_item(client, builder)
completed_jobs = client.get_jobs(status="completed")
completed_jobs = await client.get_jobs(status="completed")
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
for job in completed_jobs["jobs"]:
assert job["status"] == "completed", "Should only return completed jobs"
# Pending jobs are transient - just verify filter doesn't error
pending_jobs = client.get_jobs(status="pending")
pending_jobs = await client.get_jobs(status="pending")
for job in pending_jobs["jobs"]:
assert job["status"] == "pending", "Should only return pending jobs"
def test_get_job_by_id(
async def test_get_job_by_id(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test getting a single job by ID"""
result = self._create_history_item(client, builder)
result = await self._create_history_item(client, builder)
prompt_id = result.get_prompt_id()
job = client.get_job(prompt_id)
job = await client.get_job(prompt_id)
assert job is not None, "Should find the job"
assert job["id"] == prompt_id, "Job ID should match"
assert "outputs" in job, "Single job should include outputs"

View File

@ -0,0 +1,135 @@
# Jobs API tests
import dataclasses
import pytest
from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy_execution.graph_utils import GraphBuilder
@dataclasses.dataclass
class Result:
res:dict
def get_prompt_id(self):
return self.res["prompt_id"]
class TestJobs:
@pytest.fixture
def builder(self, request):
yield GraphBuilder(prefix=request.node.name)
@pytest.fixture
async def client(self, comfy_background_server_from_config):
async with AsyncRemoteComfyClient(f"http://localhost:{comfy_background_server_from_config[0].port}") as obj:
yield obj
async def _create_history_item(self, client, builder):
g = GraphBuilder(prefix="offset_test")
input_node = g.node(
"StubImage", content="BLACK", height=32, width=32, batch_size=1
)
g.node("SaveImage", images=input_node.out(0))
await client.queue_prompt_api(g.finalize())
async def test_jobs_api_job_structure(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test that job objects have required fields"""
await self._create_history_item(client, builder)
jobs_response = await client.get_jobs(status="completed", limit=1)
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
job = jobs_response["jobs"][0]
assert "id" in job, "Job should have id"
assert "status" in job, "Job should have status"
assert "create_time" in job, "Job should have create_time"
assert "outputs_count" in job, "Job should have outputs_count"
assert "preview_output" in job, "Job should have preview_output"
async def test_jobs_api_preview_output_structure(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test that preview_output has correct structure"""
await self._create_history_item(client, builder)
jobs_response = await client.get_jobs(status="completed", limit=1)
job = jobs_response["jobs"][0]
if job["preview_output"] is not None:
preview = job["preview_output"]
assert "filename" in preview, "Preview should have filename"
assert "nodeId" in preview, "Preview should have nodeId"
assert "mediaType" in preview, "Preview should have mediaType"
async def test_jobs_api_pagination(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test jobs API pagination"""
for _ in range(5):
await self._create_history_item(client, builder)
first_page = await client.get_jobs(limit=2, offset=0)
second_page = await client.get_jobs(limit=2, offset=2)
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
first_ids = {j["id"] for j in first_page["jobs"]}
second_ids = {j["id"] for j in second_page["jobs"]}
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
async def test_jobs_api_sorting(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test jobs API sorting"""
for _ in range(3):
await self._create_history_item(client, builder)
desc_jobs = await client.get_jobs(sort_order="desc")
asc_jobs = await client.get_jobs(sort_order="asc")
if len(desc_jobs["jobs"]) >= 2:
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
if len(desc_times) >= 2:
assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
if len(asc_times) >= 2:
assert asc_times == sorted(asc_times), "Asc should be oldest first"
async def test_jobs_api_status_filter(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test jobs API status filtering"""
await self._create_history_item(client, builder)
completed_jobs = await client.get_jobs(status="completed")
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
for job in completed_jobs["jobs"]:
assert job["status"] == "completed", "Should only return completed jobs"
# Pending jobs are transient - just verify filter doesn't error
pending_jobs = await client.get_jobs(status="pending")
for job in pending_jobs["jobs"]:
assert job["status"] == "pending", "Should only return pending jobs"
async def test_get_job_by_id(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test getting a single job by ID"""
result = await self._create_history_item(client, builder)
prompt_id = result.get_prompt_id()
job = await client.get_job(prompt_id)
assert job is not None, "Should find the job"
assert job["id"] == prompt_id, "Job ID should match"
assert "outputs" in job, "Single job should include outputs"
async def test_get_job_not_found(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test getting a non-existent job returns 404"""
job = await client.get_job("nonexistent-job-id")
assert job is None, "Non-existent job should return None"