tests pass

doctorpangloss 2025-12-26 15:20:45 -08:00
parent 1b29ff09da
commit 71698f4099
13 changed files with 314 additions and 86 deletions


@ -414,20 +414,6 @@ paths:
responses:
200:
headers:
Idempotency-Key:
description: |
The API supports idempotency for safely retrying requests without accidentally performing the same operation twice. When creating or updating an object, use an idempotency key. Then, if a connection error occurs, you can safely repeat the request without risk of creating a second object or performing the update twice.
To perform an idempotent request, provide an additional IdempotencyKey element to the request options.
Idempotency works by saving the resulting status code and body of the first request made for any given idempotency key, regardless of whether it succeeds or fails. Subsequent requests with the same key return the same result, including 500 errors.
A client generates an idempotency key, which is a unique key that the server uses to recognize subsequent retries of the same request. How you create unique keys is up to you, but we suggest using V4 UUIDs, or another random string with enough entropy to avoid collisions. Idempotency keys are up to 255 characters long.
Keys are automatically removed from the system once they are at least 24 hours old, and we generate a new request if a key is reused after the original has been pruned. The idempotency layer compares incoming parameters to those of the original request and errors if they differ, to prevent accidental misuse.
example: XFDSF000213
schema:
type: string
Digest:
description: The digest of the request body
example: SHA256=e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
@ -439,6 +425,10 @@ paths:
schema:
type: string
pattern: '^filename=.+'
Location:
description: The relative URL to revisit this request
schema:
type: string
description: |
The content of the last SaveImage node.
content:
@ -467,6 +457,10 @@ paths:
type: array
items:
type: string
prompt_id:
type: string
description:
The ID of the prompt that was queued and executed
outputs:
$ref: "#/components/schemas/Outputs"
example:
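Reviewer note: the response for this path now carries a Location header pointing at /api/v1/prompts/{task_id} plus a prompt_id field in the JSON body. A minimal client sketch, assuming the prompt is POSTed to /api/v1/prompts on a local server (the exact route, port, and a valid workflow payload are assumptions here):

import asyncio
import aiohttp

async def queue_and_track(workflow: dict) -> None:
    async with aiohttp.ClientSession() as session:
        async with session.post("http://localhost:8188/api/v1/prompts",
                                json=workflow,
                                headers={"Accept": "application/json"}) as resp:
            body = await resp.json()
            # Location is relative, e.g. /api/v1/prompts/<task_id>
            print(body["prompt_id"], resp.headers.get("Location"))

# asyncio.run(queue_and_track(my_workflow))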


@ -1,3 +1,4 @@
from __future__ import annotations
import asyncio
import uuid
from asyncio import AbstractEventLoop
@ -65,7 +66,7 @@ class AsyncRemoteComfyClient:
return headers
@tracer.start_as_current_span("Post Prompt")
async def _post_prompt(self, prompt: PromptDict | dict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
"""
Common method to POST a prompt to a given endpoint.
:param prompt: The prompt to send
@ -101,7 +102,7 @@ class AsyncRemoteComfyClient:
else:
raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
async def queue_prompt_api(self, prompt: PromptDict | dict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
"""
Calls the API to queue a prompt.
:param prompt:
@ -211,3 +212,50 @@ class AsyncRemoteComfyClient:
return response.status, None
# Timeout
return 408, None
async def get_jobs(self, status: Optional[str] = None, workflow_id: Optional[str] = None,
limit: Optional[int] = None, offset: Optional[int] = None,
sort_by: Optional[str] = None, sort_order: Optional[str] = None) -> dict:
"""
List all jobs with filtering, sorting, and pagination.
:param status: Filter by status (comma-separated): pending, in_progress, completed, failed
:param workflow_id: Filter by workflow ID
:param limit: Max items to return
:param offset: Items to skip
:param sort_by: Sort field: created_at (default), execution_duration
:param sort_order: Sort direction: asc, desc (default)
:return: Dictionary containing jobs list and pagination info
"""
params = {}
if status is not None:
params["status"] = status
if workflow_id is not None:
params["workflow_id"] = workflow_id
if limit is not None:
params["limit"] = str(limit)
if offset is not None:
params["offset"] = str(offset)
if sort_by is not None:
params["sort_by"] = sort_by
if sort_order is not None:
params["sort_order"] = sort_order
async with self.session.get(urljoin(self.server_address, "/api/jobs"), params=params) as response:
if response.status == 200:
return await response.json()
else:
raise RuntimeError(f"could not get jobs: {response.status}: {await response.text()}")
async def get_job(self, job_id: str) -> Optional[dict]:
"""
Get a single job by ID.
:param job_id: The job ID
:return: Job dictionary or None if not found
"""
async with self.session.get(urljoin(self.server_address, f"/api/jobs/{job_id}")) as response:
if response.status == 200:
return await response.json()
elif response.status == 404:
return None
else:
raise RuntimeError(f"could not get job: {response.status}: {await response.text()}")
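A short usage sketch for the new jobs endpoints, assuming a ComfyUI server is already listening on localhost:8188:

import asyncio
from comfy.client.aio_client import AsyncRemoteComfyClient

async def main() -> None:
    async with AsyncRemoteComfyClient("http://localhost:8188") as client:
        page = await client.get_jobs(status="completed", limit=10, sort_order="desc")
        for job in page["jobs"]:
            print(job["id"], job["status"], job["create_time"])
        if page["jobs"]:
            # fetch one job in full; get_job returns None when the id is unknown
            detail = await client.get_job(page["jobs"][0]["id"])
            print(detail and detail.get("outputs"))

asyncio.run(main())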


@ -24,6 +24,7 @@ class Output(TypedDict, total=False):
class V1QueuePromptResponse:
urls: List[str]
outputs: dict[str, Output]
prompt_id: str
class ProgressNotification(NamedTuple):


@ -382,19 +382,21 @@ class Comfy:
async def queue_prompt_api(self,
prompt: PromptDict | str | dict,
progress_handler: Optional[ExecutorToClientProgress] = None,
prompt_id: Optional[str] = None) -> V1QueuePromptResponse:
"""
Queues a prompt for execution, returning the output when it is complete.
:param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
:param progress_handler: an optional handler that receives execution progress updates
:param prompt_id: an optional ID for this prompt; a new UUID is generated when omitted
:return: a response of URLs for Save-related nodes and the node outputs
"""
prompt_id = prompt_id or str(uuid.uuid4())
if isinstance(prompt, str):
prompt = json.loads(prompt)
if isinstance(prompt, dict):
from ..api.components.schema.prompt import Prompt
prompt = Prompt.validate(prompt)
outputs = await self.queue_prompt(prompt, progress_handler=progress_handler, prompt_id=prompt_id)
return V1QueuePromptResponse(urls=[], outputs=outputs, prompt_id=prompt_id)
def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress:
"""


@ -1193,6 +1193,7 @@ class PromptServer(ExecutorToClientProgress):
filename = main_image["filename"]
digest_headers_ = {
"Digest": f"SHA-256={content_digest}",
"Location": f"/api/v1/prompts/{task_id}"
}
urls_ = []
if len(output_images) == 1:
@ -1224,8 +1225,10 @@ class PromptServer(ExecutorToClientProgress):
headers=digest_headers_,
body=json.dumps({
'urls': urls_,
'outputs': result.outputs,
'prompt_id': task_id,
}))
# todo: provide more ways to accept these, or support multiple file returns easily
elif accept == "image/png" or accept == "image/jpeg":
return web.FileResponse(main_image["abs_path"],
headers=digest_headers_)
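The Digest header attached above can be checked client side. A hedged sketch, assuming the value is the hex SHA-256 of the response body as in the OpenAPI example earlier in this commit:

import hashlib

def verify_digest(body: bytes, digest_header: str) -> bool:
    algo, _, expected = digest_header.partition("=")
    if algo.upper().replace("-", "") != "SHA256":
        return False  # unknown algorithm, refuse to claim a match
    return hashlib.sha256(body).hexdigest() == expected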


@ -23,6 +23,7 @@ from .k_diffusion import sampling as k_diffusion_sampling
from .model_base import BaseModel
from .model_management_types import ModelOptions
from .model_patcher import ModelPatcher
from .sampler_helpers import prepare_mask
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES, KSAMPLER_NAMES
from .context_windows import ContextHandlerABC
from .utils import common_upscale, pack_latents, unpack_latents
@ -1068,10 +1069,10 @@ class CFGGuider:
denoise_masks.append(torch.ones(latent_shapes[i]))
for i in range(len(denoise_masks)):
denoise_masks[i] = prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
if len(denoise_masks) > 1:
denoise_mask, _ = pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]


@ -6,59 +6,68 @@ import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple, deque
from ..ops import disable_weight_init
operations=comfy.ops.disable_weight_init
operations = disable_weight_init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out, act_func):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = act_func
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f * stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f, n_f * stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
B, T, C, H, W = x.shape
if parallel:
x = x.reshape(B * T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
BT, C, H, W = x.shape
T = BT // B
_x = x.reshape(B, T, C, H, W)
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
@ -87,25 +96,25 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
xt_new = b(xt, mem[i])
mem[i] = xt.detach().clone()
del xt
work_queue.appendleft(TWorkItem(xt_new, i + 1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt.detach().clone())
if len(mem[i]) == b.stride:
B, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(B * b.stride, C, H, W))
mem[i] = []
work_queue.appendleft(TWorkItem(xt, i + 1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
for xt_next in reversed(xt.view(B, b.stride * C, H, W).chunk(b.stride, 1)):
work_queue.appendleft(TWorkItem(xt_next, i + 1))
del xt
else:
xt = b(xt)
work_queue.appendleft(TWorkItem(xt, i + 1))
progress_bar.close()
x = torch.stack(out, 1)
return x
@ -122,29 +131,30 @@ class TAEHV(nn.Module):
self.show_progress_bar = show_progress_bar
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
self.patch_size = 2
if self.latent_channels == 32: # HunyuanVideo1.5
act_func = nn.LeakyReLU(0.2, inplace=True)
else: # HunyuanVideo, Wan 2.1
act_func = nn.ReLU(inplace=True)
self.encoder = nn.Sequential(
conv(self.image_channels * self.patch_size ** 2, 64), act_func,
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
act_func, conv(n_f[3], self.image_channels * self.patch_size ** 2),
)
@property
def show_progress_bar(self):
return self._show_progress_bar
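TPool and TGrow above move time into the channel dimension and back: TPool folds stride consecutive frames into channels and mixes them with a 1x1 conv, while TGrow emits stride output frames per input frame. A standalone illustration of the reshape trick with plain torch (shapes are made up for the example):

import torch
import torch.nn as nn

B, T, C, H, W, stride = 1, 4, 8, 16, 16, 2
x = torch.randn(B * T, C, H, W)  # frames flattened into the batch dimension

pool = nn.Conv2d(C * stride, C, 1, bias=False)
pooled = pool(x.reshape(-1, stride * C, H, W))  # (B*T/stride, C, H, W)

grow = nn.Conv2d(C, C * stride, 1, bias=False)
grown = grow(pooled).reshape(-1, C, H, W)       # back to (B*T, C, H, W)
print(pooled.shape, grown.shape)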


@ -8,13 +8,15 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
from .. import sd1_clip
from ..ldm.modules.attention import optimized_attention_for_device
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer = tokenizer_data.get("spiece_model", None)
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
@ -22,10 +24,14 @@ class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
@ -44,6 +50,7 @@ class XLMRobertaConfig:
eos_token_id: int = 2
pad_token_id: int = 1
class XLMRobertaEmbeddings(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -61,6 +68,7 @@ class XLMRobertaEmbeddings(nn.Module):
embeddings = embeddings + token_type_embeddings
return embeddings
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base, device=None):
super().__init__()
@ -95,6 +103,7 @@ class RotaryEmbedding(nn.Module):
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class MHA(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -122,6 +131,7 @@ class MHA(nn.Module):
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
return self.out_proj(out)
class MLP(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -135,6 +145,7 @@ class MLP(nn.Module):
x = self.fc2(x)
return x
class Block(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -152,17 +163,19 @@ class Block(nn.Module):
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
return hidden_states
class XLMRobertaEncoder(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None):
optimized_attention = optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
for layer in self.layers:
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
return hidden_states
class XLMRobertaModel_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -171,7 +184,9 @@ class XLMRobertaModel_(nn.Module):
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=None):
if embeds_info is None:
embeds_info = []
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
x = self.emb_ln(x)
x = self.emb_drop(x)
@ -194,6 +209,7 @@ class XLMRobertaModel_(nn.Module):
# Intermediate output is not yet implemented, use None for placeholder
return sequence_output, None, pooled_output
class XLMRobertaModel(nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@ -210,10 +226,17 @@ class XLMRobertaModel(nn.Module):
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
class JinaClip2TextModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
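The signature changes in this file all replace a mutable default ({}) with a None sentinel. The reason, in a self-contained sketch: a default dict is created once at function definition time and then shared by every call.

def bad(options={}):
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options

def good(options=None):
    if options is None:
        options = {}
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options

print(bad())   # {'calls': 1}
print(bad())   # {'calls': 2}  - the same dict, mutated again
print(good())  # {'calls': 1}
print(good())  # {'calls': 1}  - a fresh dict on every call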


@ -1,15 +1,18 @@
import torch
from ..model_management import pick_weight_dtype
from .jina_clip_2 import JinaClip2TextModel, JinaClip2Tokenizer
from .lumina2 import Gemma3_4BTokenizer, Gemma3_4BModel
class NewBieTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
self.gemma = Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
self.jina = JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
out = {}
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
@ -21,12 +24,15 @@ class NewBieTokenizer:
def state_dict(self):
return {}
class NewBieTEModel(torch.nn.Module):
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options=None):
super().__init__()
if model_options is None:
model_options = {}
dtype_gemma = pick_weight_dtype(dtype_gemma, dtype, device)
self.gemma = Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
self.jina = JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = {dtype, dtype_gemma}
def set_clip_options(self, options):
@ -52,11 +58,15 @@ class NewBieTEModel(torch.nn.Module):
else:
return self.jina.load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class NewBieTEModel_(NewBieTEModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return NewBieTEModel_
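te() uses a small factory pattern: capture configuration (here the quantization metadata) in a closure and return a class that is ready to instantiate with the usual constructor arguments. A generic sketch of the same idea, with made-up names and assuming the base class accepts a model_options keyword:

def make_configured(base_cls, extra_options: dict):
    class Configured(base_cls):
        def __init__(self, *args, model_options=None, **kwargs):
            merged = dict(extra_options)
            merged.update(model_options or {})
            super().__init__(*args, model_options=merged, **kwargs)
    return Configured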


@ -1,6 +1,6 @@
import copy
import logging
from ..utils import detect_layer_quantization
import torch
from ..transformers_compat import T5TokenizerFast
@ -33,7 +33,7 @@ def t5_xxl_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_t5"] = state_dict[t5_key].dtype
quant = detect_layer_quantization(state_dict, prefix)
if quant is not None:
out["t5_quantization_metadata"] = quant


@ -5,7 +5,8 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
import torch
from comfy.nodes import base_nodes as nodes
class TextEncodeQwenImageEdit(io.ComfyNode):
@classmethod


@ -845,13 +845,13 @@ class TestExecution:
assert numpy.array(result_image).mean() == 0, "Image should be black"
# Jobs API tests
async def test_jobs_api_job_structure(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test that job objects have required fields"""
await self._create_history_item(client, builder)
jobs_response = await client.get_jobs(status="completed", limit=1)
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
job = jobs_response["jobs"][0]
@ -861,13 +861,13 @@ class TestExecution:
assert "outputs_count" in job, "Job should have outputs_count" assert "outputs_count" in job, "Job should have outputs_count"
assert "preview_output" in job, "Job should have preview_output" assert "preview_output" in job, "Job should have preview_output"
def test_jobs_api_preview_output_structure( async def test_jobs_api_preview_output_structure(
self, client: ComfyClient, builder: GraphBuilder self, client: ComfyClient, builder: GraphBuilder
): ):
"""Test that preview_output has correct structure""" """Test that preview_output has correct structure"""
self._create_history_item(client, builder) await self._create_history_item(client, builder)
jobs_response = client.get_jobs(status="completed", limit=1) jobs_response = await client.get_jobs(status="completed", limit=1)
job = jobs_response["jobs"][0] job = jobs_response["jobs"][0]
if job["preview_output"] is not None: if job["preview_output"] is not None:
@ -876,15 +876,15 @@ class TestExecution:
assert "nodeId" in preview, "Preview should have nodeId" assert "nodeId" in preview, "Preview should have nodeId"
assert "mediaType" in preview, "Preview should have mediaType" assert "mediaType" in preview, "Preview should have mediaType"
def test_jobs_api_pagination( async def test_jobs_api_pagination(
self, client: ComfyClient, builder: GraphBuilder self, client: ComfyClient, builder: GraphBuilder
): ):
"""Test jobs API pagination""" """Test jobs API pagination"""
for _ in range(5): for _ in range(5):
self._create_history_item(client, builder) await self._create_history_item(client, builder)
first_page = client.get_jobs(limit=2, offset=0) first_page = await client.get_jobs(limit=2, offset=0)
second_page = client.get_jobs(limit=2, offset=2) second_page = await client.get_jobs(limit=2, offset=2)
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs" assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs" assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
@ -893,15 +893,15 @@ class TestExecution:
second_ids = {j["id"] for j in second_page["jobs"]}
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
async def test_jobs_api_sorting(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API sorting"""
for _ in range(3):
await self._create_history_item(client, builder)
desc_jobs = await client.get_jobs(sort_order="desc")
asc_jobs = await client.get_jobs(sort_order="asc")
if len(desc_jobs["jobs"]) >= 2:
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
@ -911,31 +911,31 @@ class TestExecution:
if len(asc_times) >= 2:
assert asc_times == sorted(asc_times), "Asc should be oldest first"
async def test_jobs_api_status_filter(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API status filtering"""
await self._create_history_item(client, builder)
completed_jobs = await client.get_jobs(status="completed")
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
for job in completed_jobs["jobs"]:
assert job["status"] == "completed", "Should only return completed jobs"
# Pending jobs are transient - just verify filter doesn't error
pending_jobs = await client.get_jobs(status="pending")
for job in pending_jobs["jobs"]:
assert job["status"] == "pending", "Should only return pending jobs"
async def test_get_job_by_id(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test getting a single job by ID"""
result = await self._create_history_item(client, builder)
prompt_id = result.get_prompt_id()
job = await client.get_job(prompt_id)
assert job is not None, "Should find the job"
assert job["id"] == prompt_id, "Job ID should match"
assert "outputs" in job, "Single job should include outputs"


@ -0,0 +1,135 @@
# Jobs API tests
import dataclasses
import pytest
from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy_execution.graph_utils import GraphBuilder
@dataclasses.dataclass
class Result:
res: dict
def get_prompt_id(self):
return self.res["prompt_id"]
class TestJobs:
@pytest.fixture
def builder(self, request):
yield GraphBuilder(prefix=request.node.name)
@pytest.fixture
async def client(self, comfy_background_server_from_config):
async with AsyncRemoteComfyClient(f"http://localhost:{comfy_background_server_from_config[0].port}") as obj:
yield obj
async def _create_history_item(self, client, builder):
g = GraphBuilder(prefix="offset_test")
input_node = g.node(
"StubImage", content="BLACK", height=32, width=32, batch_size=1
)
g.node("SaveImage", images=input_node.out(0))
response = await client.queue_prompt_api(g.finalize())
# wrap the prompt id so callers such as test_get_job_by_id can retrieve it
return Result(res={"prompt_id": response.prompt_id})
async def test_jobs_api_job_structure(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test that job objects have required fields"""
await self._create_history_item(client, builder)
jobs_response = await client.get_jobs(status="completed", limit=1)
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
job = jobs_response["jobs"][0]
assert "id" in job, "Job should have id"
assert "status" in job, "Job should have status"
assert "create_time" in job, "Job should have create_time"
assert "outputs_count" in job, "Job should have outputs_count"
assert "preview_output" in job, "Job should have preview_output"
async def test_jobs_api_preview_output_structure(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test that preview_output has correct structure"""
await self._create_history_item(client, builder)
jobs_response = await client.get_jobs(status="completed", limit=1)
job = jobs_response["jobs"][0]
if job["preview_output"] is not None:
preview = job["preview_output"]
assert "filename" in preview, "Preview should have filename"
assert "nodeId" in preview, "Preview should have nodeId"
assert "mediaType" in preview, "Preview should have mediaType"
async def test_jobs_api_pagination(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test jobs API pagination"""
for _ in range(5):
await self._create_history_item(client, builder)
first_page = await client.get_jobs(limit=2, offset=0)
second_page = await client.get_jobs(limit=2, offset=2)
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
first_ids = {j["id"] for j in first_page["jobs"]}
second_ids = {j["id"] for j in second_page["jobs"]}
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
async def test_jobs_api_sorting(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test jobs API sorting"""
for _ in range(3):
await self._create_history_item(client, builder)
desc_jobs = await client.get_jobs(sort_order="desc")
asc_jobs = await client.get_jobs(sort_order="asc")
if len(desc_jobs["jobs"]) >= 2:
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
if len(desc_times) >= 2:
assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
if len(asc_times) >= 2:
assert asc_times == sorted(asc_times), "Asc should be oldest first"
async def test_jobs_api_status_filter(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test jobs API status filtering"""
await self._create_history_item(client, builder)
completed_jobs = await client.get_jobs(status="completed")
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
for job in completed_jobs["jobs"]:
assert job["status"] == "completed", "Should only return completed jobs"
# Pending jobs are transient - just verify filter doesn't error
pending_jobs = await client.get_jobs(status="pending")
for job in pending_jobs["jobs"]:
assert job["status"] == "pending", "Should only return pending jobs"
async def test_get_job_by_id(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test getting a single job by ID"""
result = await self._create_history_item(client, builder)
prompt_id = result.get_prompt_id()
job = await client.get_job(prompt_id)
assert job is not None, "Should find the job"
assert job["id"] == prompt_id, "Job ID should match"
assert "outputs" in job, "Single job should include outputs"
async def test_get_job_not_found(
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
):
"""Test getting a non-existent job returns 404"""
job = await client.get_job("nonexistent-job-id")
assert job is None, "Non-existent job should return None"