mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-06 12:10:51 +08:00
tests pass
This commit is contained in:
parent
1b29ff09da
commit
71698f4099
@ -414,20 +414,6 @@ paths:
|
||||
responses:
|
||||
200:
|
||||
headers:
|
||||
Idempotency-Key:
|
||||
description: |
|
||||
The API supports idempotency for safely retrying requests without accidentally performing the same operation twice. When creating or updating an object, use an idempotency key. Then, if a connection error occurs, you can safely repeat the request without risk of creating a second object or performing the update twice.
|
||||
|
||||
To perform an idempotent request, provide an additional IdempotencyKey element to the request options.
|
||||
|
||||
Idempotency works by saving the resulting status code and body of the first request made for any given idempotency key, regardless of whether it succeeds or fails. Subsequent requests with the same key return the same result, including 500 errors.
|
||||
|
||||
A client generates an idempotency key, which is a unique key that the server uses to recognize subsequent retries of the same request. How you create unique keys is up to you, but we suggest using V4 UUIDs, or another random string with enough entropy to avoid collisions. Idempotency keys are up to 255 characters long.
|
||||
|
||||
You can remove keys from the system automatically after they’re at least 24 hours old. We generate a new request if a key is reused after the original is pruned. The idempotency layer compares incoming parameters to those of the original request and errors if they’re the same to prevent accidental misuse.
|
||||
example: XFDSF000213
|
||||
schema:
|
||||
type: string
|
||||
Digest:
|
||||
description: The digest of the request body
|
||||
example: SHA256=e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
|
||||
@ -439,6 +425,10 @@ paths:
|
||||
schema:
|
||||
type: string
|
||||
pattern: '^filename=.+'
|
||||
Location:
|
||||
description: The relative URL to revisit this request
|
||||
schema:
|
||||
type: string
|
||||
description: |
|
||||
The content of the last SaveImage node.
|
||||
content:
|
||||
@ -467,6 +457,10 @@ paths:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
prompt_id:
|
||||
type: string
|
||||
description:
|
||||
The ID of the prompt that was queued and executed
|
||||
outputs:
|
||||
$ref: "#/components/schemas/Outputs"
|
||||
example:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import uuid
|
||||
from asyncio import AbstractEventLoop
|
||||
@ -65,7 +66,7 @@ class AsyncRemoteComfyClient:
|
||||
return headers
|
||||
|
||||
@tracer.start_as_current_span("Post Prompt")
|
||||
async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
|
||||
async def _post_prompt(self, prompt: PromptDict | dict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
|
||||
"""
|
||||
Common method to POST a prompt to a given endpoint.
|
||||
:param prompt: The prompt to send
|
||||
@ -101,7 +102,7 @@ class AsyncRemoteComfyClient:
|
||||
else:
|
||||
raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
|
||||
|
||||
async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
|
||||
async def queue_prompt_api(self, prompt: PromptDict | dict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
|
||||
"""
|
||||
Calls the API to queue a prompt.
|
||||
:param prompt:
|
||||
@ -211,3 +212,50 @@ class AsyncRemoteComfyClient:
|
||||
return response.status, None
|
||||
# Timeout
|
||||
return 408, None
|
||||
|
||||
async def get_jobs(self, status: Optional[str] = None, workflow_id: Optional[str] = None,
|
||||
limit: Optional[int] = None, offset: Optional[int] = None,
|
||||
sort_by: Optional[str] = None, sort_order: Optional[str] = None) -> dict:
|
||||
"""
|
||||
List all jobs with filtering, sorting, and pagination.
|
||||
:param status: Filter by status (comma-separated): pending, in_progress, completed, failed
|
||||
:param workflow_id: Filter by workflow ID
|
||||
:param limit: Max items to return
|
||||
:param offset: Items to skip
|
||||
:param sort_by: Sort field: created_at (default), execution_duration
|
||||
:param sort_order: Sort direction: asc, desc (default)
|
||||
:return: Dictionary containing jobs list and pagination info
|
||||
"""
|
||||
params = {}
|
||||
if status is not None:
|
||||
params["status"] = status
|
||||
if workflow_id is not None:
|
||||
params["workflow_id"] = workflow_id
|
||||
if limit is not None:
|
||||
params["limit"] = str(limit)
|
||||
if offset is not None:
|
||||
params["offset"] = str(offset)
|
||||
if sort_by is not None:
|
||||
params["sort_by"] = sort_by
|
||||
if sort_order is not None:
|
||||
params["sort_order"] = sort_order
|
||||
|
||||
async with self.session.get(urljoin(self.server_address, "/api/jobs"), params=params) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
raise RuntimeError(f"could not get jobs: {response.status}: {await response.text()}")
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get a single job by ID.
|
||||
:param job_id: The job ID
|
||||
:return: Job dictionary or None if not found
|
||||
"""
|
||||
async with self.session.get(urljoin(self.server_address, f"/api/jobs/{job_id}")) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
elif response.status == 404:
|
||||
return None
|
||||
else:
|
||||
raise RuntimeError(f"could not get job: {response.status}: {await response.text()}")
|
||||
|
||||
@ -24,6 +24,7 @@ class Output(TypedDict, total=False):
|
||||
class V1QueuePromptResponse:
|
||||
urls: List[str]
|
||||
outputs: dict[str, Output]
|
||||
prompt_id: str
|
||||
|
||||
|
||||
class ProgressNotification(NamedTuple):
|
||||
|
||||
@ -382,19 +382,21 @@ class Comfy:
|
||||
|
||||
async def queue_prompt_api(self,
|
||||
prompt: PromptDict | str | dict,
|
||||
progress_handler: Optional[ExecutorToClientProgress] = None) -> V1QueuePromptResponse:
|
||||
progress_handler: Optional[ExecutorToClientProgress] = None,
|
||||
prompt_id: Optional[str] = None) -> V1QueuePromptResponse:
|
||||
"""
|
||||
Queues a prompt for execution, returning the output when it is complete.
|
||||
:param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
|
||||
:return: a response of URLs for Save-related nodes and the node outputs
|
||||
"""
|
||||
prompt_id = prompt_id or str(uuid.uuid4())
|
||||
if isinstance(prompt, str):
|
||||
prompt = json.loads(prompt)
|
||||
if isinstance(prompt, dict):
|
||||
from ..api.components.schema.prompt import Prompt
|
||||
prompt = Prompt.validate(prompt)
|
||||
outputs = await self.queue_prompt(prompt, progress_handler=progress_handler)
|
||||
return V1QueuePromptResponse(urls=[], outputs=outputs)
|
||||
outputs = await self.queue_prompt(prompt, progress_handler=progress_handler, prompt_id=prompt_id)
|
||||
return V1QueuePromptResponse(urls=[], outputs=outputs, prompt_id=prompt_id)
|
||||
|
||||
def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress:
|
||||
"""
|
||||
|
||||
@ -1193,6 +1193,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
filename = main_image["filename"]
|
||||
digest_headers_ = {
|
||||
"Digest": f"SHA-256={content_digest}",
|
||||
"Location": f"/api/v1/prompts/{task_id}"
|
||||
}
|
||||
urls_ = []
|
||||
if len(output_images) == 1:
|
||||
@ -1224,8 +1225,10 @@ class PromptServer(ExecutorToClientProgress):
|
||||
headers=digest_headers_,
|
||||
body=json.dumps({
|
||||
'urls': urls_,
|
||||
'outputs': result.outputs
|
||||
'outputs': result.outputs,
|
||||
'prompt_id': task_id,
|
||||
}))
|
||||
# todo: provide more ways to accept these, or support multiple file returns easily
|
||||
elif accept == "image/png" or accept == "image/jpeg":
|
||||
return web.FileResponse(main_image["abs_path"],
|
||||
headers=digest_headers_)
|
||||
|
||||
@ -23,6 +23,7 @@ from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .model_base import BaseModel
|
||||
from .model_management_types import ModelOptions
|
||||
from .model_patcher import ModelPatcher
|
||||
from .sampler_helpers import prepare_mask
|
||||
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES, KSAMPLER_NAMES
|
||||
from .context_windows import ContextHandlerABC
|
||||
from .utils import common_upscale, pack_latents, unpack_latents
|
||||
@ -1068,10 +1069,10 @@ class CFGGuider:
|
||||
denoise_masks.append(torch.ones(latent_shapes[i]))
|
||||
|
||||
for i in range(len(denoise_masks)):
|
||||
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
|
||||
denoise_masks[i] = prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
|
||||
|
||||
if len(denoise_masks) > 1:
|
||||
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
||||
denoise_mask, _ = pack_latents(denoise_masks)
|
||||
else:
|
||||
denoise_mask = denoise_masks[0]
|
||||
|
||||
|
||||
@ -6,59 +6,68 @@ import torch.nn.functional as F
|
||||
from tqdm.auto import tqdm
|
||||
from collections import namedtuple, deque
|
||||
|
||||
import comfy.ops
|
||||
operations=comfy.ops.disable_weight_init
|
||||
from ..ops import disable_weight_init
|
||||
|
||||
operations = disable_weight_init
|
||||
|
||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
||||
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
||||
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
|
||||
|
||||
class Clamp(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.tanh(x / 3) * 3
|
||||
|
||||
|
||||
class MemBlock(nn.Module):
|
||||
def __init__(self, n_in, n_out, act_func):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
|
||||
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.act = act_func
|
||||
|
||||
def forward(self, x, past):
|
||||
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
||||
|
||||
|
||||
class TPool(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
|
||||
self.conv = operations.Conv2d(n_f * stride, n_f, 1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
||||
|
||||
|
||||
class TGrow(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
|
||||
self.conv = operations.Conv2d(n_f, n_f * stride, 1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
x = self.conv(x)
|
||||
return x.reshape(-1, C, H, W)
|
||||
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
B, T, C, H, W = x.shape
|
||||
if parallel:
|
||||
x = x.reshape(B*T, C, H, W)
|
||||
x = x.reshape(B * T, C, H, W)
|
||||
# parallel over input timesteps, iterate over blocks
|
||||
for b in tqdm(model, disable=not show_progress_bar):
|
||||
if isinstance(b, MemBlock):
|
||||
BT, C, H, W = x.shape
|
||||
T = BT // B
|
||||
_x = x.reshape(B, T, C, H, W)
|
||||
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
|
||||
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
|
||||
x = b(x, mem)
|
||||
else:
|
||||
x = b(x)
|
||||
@ -87,25 +96,25 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
xt_new = b(xt, mem[i])
|
||||
mem[i] = xt.detach().clone()
|
||||
del xt
|
||||
work_queue.appendleft(TWorkItem(xt_new, i+1))
|
||||
work_queue.appendleft(TWorkItem(xt_new, i + 1))
|
||||
elif isinstance(b, TPool):
|
||||
if mem[i] is None:
|
||||
mem[i] = []
|
||||
mem[i].append(xt.detach().clone())
|
||||
if len(mem[i]) == b.stride:
|
||||
B, C, H, W = xt.shape
|
||||
xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
|
||||
xt = b(torch.cat(mem[i], 1).view(B * b.stride, C, H, W))
|
||||
mem[i] = []
|
||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
||||
work_queue.appendleft(TWorkItem(xt, i + 1))
|
||||
elif isinstance(b, TGrow):
|
||||
xt = b(xt)
|
||||
NT, C, H, W = xt.shape
|
||||
for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
|
||||
work_queue.appendleft(TWorkItem(xt_next, i+1))
|
||||
for xt_next in reversed(xt.view(B, b.stride * C, H, W).chunk(b.stride, 1)):
|
||||
work_queue.appendleft(TWorkItem(xt_next, i + 1))
|
||||
del xt
|
||||
else:
|
||||
xt = b(xt)
|
||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
||||
work_queue.appendleft(TWorkItem(xt, i + 1))
|
||||
progress_bar.close()
|
||||
x = torch.stack(out, 1)
|
||||
return x
|
||||
@ -122,29 +131,30 @@ class TAEHV(nn.Module):
|
||||
self.show_progress_bar = show_progress_bar
|
||||
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
|
||||
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
|
||||
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
||||
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
||||
self.patch_size = 2
|
||||
if self.latent_channels == 32: # HunyuanVideo1.5
|
||||
if self.latent_channels == 32: # HunyuanVideo1.5
|
||||
act_func = nn.LeakyReLU(0.2, inplace=True)
|
||||
else: # HunyuanVideo, Wan 2.1
|
||||
else: # HunyuanVideo, Wan 2.1
|
||||
act_func = nn.ReLU(inplace=True)
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
conv(self.image_channels*self.patch_size**2, 64), act_func,
|
||||
conv(self.image_channels * self.patch_size ** 2, 64), act_func,
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
conv(64, self.latent_channels),
|
||||
)
|
||||
n_f = [256, 128, 64, 64]
|
||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
||||
self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
|
||||
self.decoder = nn.Sequential(
|
||||
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
|
||||
act_func, conv(n_f[3], self.image_channels * self.patch_size ** 2),
|
||||
)
|
||||
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
@ -8,13 +8,15 @@ import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
from .. import sd1_clip
|
||||
from ..ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
|
||||
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = {}
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
|
||||
@ -22,10 +24,14 @@ class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
|
||||
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = {}
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
|
||||
|
||||
|
||||
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
|
||||
@dataclass
|
||||
class XLMRobertaConfig:
|
||||
@ -44,6 +50,7 @@ class XLMRobertaConfig:
|
||||
eos_token_id: int = 2
|
||||
pad_token_id: int = 1
|
||||
|
||||
|
||||
class XLMRobertaEmbeddings(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
@ -61,6 +68,7 @@ class XLMRobertaEmbeddings(nn.Module):
|
||||
embeddings = embeddings + token_type_embeddings
|
||||
return embeddings
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, base, device=None):
|
||||
super().__init__()
|
||||
@ -95,6 +103,7 @@ class RotaryEmbedding(nn.Module):
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class MHA(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
@ -122,6 +131,7 @@ class MHA(nn.Module):
|
||||
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
@ -135,6 +145,7 @@ class MLP(nn.Module):
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
@ -152,17 +163,19 @@ class Block(nn.Module):
|
||||
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class XLMRobertaEncoder(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None):
|
||||
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
|
||||
optimized_attention = optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class XLMRobertaModel_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
@ -171,7 +184,9 @@ class XLMRobertaModel_(nn.Module):
|
||||
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=None):
|
||||
if embeds_info is None:
|
||||
embeds_info = []
|
||||
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
|
||||
x = self.emb_ln(x)
|
||||
x = self.emb_drop(x)
|
||||
@ -194,6 +209,7 @@ class XLMRobertaModel_(nn.Module):
|
||||
# Intermediate output is not yet implemented, use None for placeholder
|
||||
return sequence_output, None, pooled_output
|
||||
|
||||
|
||||
class XLMRobertaModel(nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
@ -210,10 +226,17 @@ class XLMRobertaModel(nn.Module):
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
|
||||
class JinaClip2TextModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
|
||||
if model_options is None:
|
||||
model_options = {}
|
||||
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
|
||||
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
if model_options is None:
|
||||
model_options = {}
|
||||
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
|
||||
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.lumina2
|
||||
from ..model_management import pick_weight_dtype
|
||||
from .jina_clip_2 import JinaClip2TextModel, JinaClip2Tokenizer
|
||||
from .lumina2 import Gemma3_4BTokenizer, Gemma3_4BModel
|
||||
|
||||
|
||||
class NewBieTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
|
||||
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = {}
|
||||
self.gemma = Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
|
||||
self.jina = JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
@ -21,12 +24,15 @@ class NewBieTokenizer:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
|
||||
class NewBieTEModel(torch.nn.Module):
|
||||
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
|
||||
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options=None):
|
||||
super().__init__()
|
||||
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
|
||||
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
|
||||
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
|
||||
if model_options is None:
|
||||
model_options = {}
|
||||
dtype_gemma = pick_weight_dtype(dtype_gemma, dtype, device)
|
||||
self.gemma = Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
|
||||
self.jina = JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = {dtype, dtype_gemma}
|
||||
|
||||
def set_clip_options(self, options):
|
||||
@ -52,11 +58,15 @@ class NewBieTEModel(torch.nn.Module):
|
||||
else:
|
||||
return self.jina.load_sd(sd)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class NewBieTEModel_(NewBieTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
if model_options is None:
|
||||
model_options = {}
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
return NewBieTEModel_
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import copy
|
||||
import logging
|
||||
import comfy.utils
|
||||
from ..utils import detect_layer_quantization
|
||||
|
||||
import torch
|
||||
from ..transformers_compat import T5TokenizerFast
|
||||
@ -33,7 +33,7 @@ def t5_xxl_detect(state_dict, prefix=""):
|
||||
if t5_key in state_dict:
|
||||
out["dtype_t5"] = state_dict[t5_key].dtype
|
||||
|
||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
||||
quant = detect_layer_quantization(state_dict, prefix)
|
||||
if quant is not None:
|
||||
out["t5_quantization_metadata"] = quant
|
||||
|
||||
|
||||
@ -5,7 +5,8 @@ from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import comfy.model_management
|
||||
import torch
|
||||
import nodes
|
||||
from comfy.nodes import base_nodes as nodes
|
||||
|
||||
|
||||
class TextEncodeQwenImageEdit(io.ComfyNode):
|
||||
@classmethod
|
||||
|
||||
@ -845,13 +845,13 @@ class TestExecution:
|
||||
assert numpy.array(result_image).mean() == 0, "Image should be black"
|
||||
|
||||
# Jobs API tests
|
||||
def test_jobs_api_job_structure(
|
||||
async def test_jobs_api_job_structure(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test that job objects have required fields"""
|
||||
self._create_history_item(client, builder)
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
jobs_response = client.get_jobs(status="completed", limit=1)
|
||||
jobs_response = await client.get_jobs(status="completed", limit=1)
|
||||
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
|
||||
|
||||
job = jobs_response["jobs"][0]
|
||||
@ -861,13 +861,13 @@ class TestExecution:
|
||||
assert "outputs_count" in job, "Job should have outputs_count"
|
||||
assert "preview_output" in job, "Job should have preview_output"
|
||||
|
||||
def test_jobs_api_preview_output_structure(
|
||||
async def test_jobs_api_preview_output_structure(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test that preview_output has correct structure"""
|
||||
self._create_history_item(client, builder)
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
jobs_response = client.get_jobs(status="completed", limit=1)
|
||||
jobs_response = await client.get_jobs(status="completed", limit=1)
|
||||
job = jobs_response["jobs"][0]
|
||||
|
||||
if job["preview_output"] is not None:
|
||||
@ -876,15 +876,15 @@ class TestExecution:
|
||||
assert "nodeId" in preview, "Preview should have nodeId"
|
||||
assert "mediaType" in preview, "Preview should have mediaType"
|
||||
|
||||
def test_jobs_api_pagination(
|
||||
async def test_jobs_api_pagination(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test jobs API pagination"""
|
||||
for _ in range(5):
|
||||
self._create_history_item(client, builder)
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
first_page = client.get_jobs(limit=2, offset=0)
|
||||
second_page = client.get_jobs(limit=2, offset=2)
|
||||
first_page = await client.get_jobs(limit=2, offset=0)
|
||||
second_page = await client.get_jobs(limit=2, offset=2)
|
||||
|
||||
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
|
||||
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
|
||||
@ -893,15 +893,15 @@ class TestExecution:
|
||||
second_ids = {j["id"] for j in second_page["jobs"]}
|
||||
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
|
||||
|
||||
def test_jobs_api_sorting(
|
||||
async def test_jobs_api_sorting(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test jobs API sorting"""
|
||||
for _ in range(3):
|
||||
self._create_history_item(client, builder)
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
desc_jobs = client.get_jobs(sort_order="desc")
|
||||
asc_jobs = client.get_jobs(sort_order="asc")
|
||||
desc_jobs = await client.get_jobs(sort_order="desc")
|
||||
asc_jobs = await client.get_jobs(sort_order="asc")
|
||||
|
||||
if len(desc_jobs["jobs"]) >= 2:
|
||||
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
|
||||
@ -911,31 +911,31 @@ class TestExecution:
|
||||
if len(asc_times) >= 2:
|
||||
assert asc_times == sorted(asc_times), "Asc should be oldest first"
|
||||
|
||||
def test_jobs_api_status_filter(
|
||||
async def test_jobs_api_status_filter(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test jobs API status filtering"""
|
||||
self._create_history_item(client, builder)
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
completed_jobs = client.get_jobs(status="completed")
|
||||
completed_jobs = await client.get_jobs(status="completed")
|
||||
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
|
||||
|
||||
for job in completed_jobs["jobs"]:
|
||||
assert job["status"] == "completed", "Should only return completed jobs"
|
||||
|
||||
# Pending jobs are transient - just verify filter doesn't error
|
||||
pending_jobs = client.get_jobs(status="pending")
|
||||
pending_jobs = await client.get_jobs(status="pending")
|
||||
for job in pending_jobs["jobs"]:
|
||||
assert job["status"] == "pending", "Should only return pending jobs"
|
||||
|
||||
def test_get_job_by_id(
|
||||
async def test_get_job_by_id(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test getting a single job by ID"""
|
||||
result = self._create_history_item(client, builder)
|
||||
result = await self._create_history_item(client, builder)
|
||||
prompt_id = result.get_prompt_id()
|
||||
|
||||
job = client.get_job(prompt_id)
|
||||
job = await client.get_job(prompt_id)
|
||||
assert job is not None, "Should find the job"
|
||||
assert job["id"] == prompt_id, "Job ID should match"
|
||||
assert "outputs" in job, "Single job should include outputs"
|
||||
|
||||
135
tests/execution/test_jobs_from_execution.py
Normal file
135
tests/execution/test_jobs_from_execution.py
Normal file
@ -0,0 +1,135 @@
|
||||
# Jobs API tests
|
||||
import dataclasses
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy.client.aio_client import AsyncRemoteComfyClient
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Result:
|
||||
res:dict
|
||||
|
||||
def get_prompt_id(self):
|
||||
return self.res["prompt_id"]
|
||||
|
||||
|
||||
class TestJobs:
|
||||
@pytest.fixture
|
||||
def builder(self, request):
|
||||
yield GraphBuilder(prefix=request.node.name)
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, comfy_background_server_from_config):
|
||||
async with AsyncRemoteComfyClient(f"http://localhost:{comfy_background_server_from_config[0].port}") as obj:
|
||||
yield obj
|
||||
|
||||
async def _create_history_item(self, client, builder):
|
||||
g = GraphBuilder(prefix="offset_test")
|
||||
input_node = g.node(
|
||||
"StubImage", content="BLACK", height=32, width=32, batch_size=1
|
||||
)
|
||||
g.node("SaveImage", images=input_node.out(0))
|
||||
await client.queue_prompt_api(g.finalize())
|
||||
|
||||
async def test_jobs_api_job_structure(
|
||||
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test that job objects have required fields"""
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
jobs_response = await client.get_jobs(status="completed", limit=1)
|
||||
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
|
||||
|
||||
job = jobs_response["jobs"][0]
|
||||
assert "id" in job, "Job should have id"
|
||||
assert "status" in job, "Job should have status"
|
||||
assert "create_time" in job, "Job should have create_time"
|
||||
assert "outputs_count" in job, "Job should have outputs_count"
|
||||
assert "preview_output" in job, "Job should have preview_output"
|
||||
|
||||
async def test_jobs_api_preview_output_structure(
|
||||
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test that preview_output has correct structure"""
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
jobs_response = await client.get_jobs(status="completed", limit=1)
|
||||
job = jobs_response["jobs"][0]
|
||||
|
||||
if job["preview_output"] is not None:
|
||||
preview = job["preview_output"]
|
||||
assert "filename" in preview, "Preview should have filename"
|
||||
assert "nodeId" in preview, "Preview should have nodeId"
|
||||
assert "mediaType" in preview, "Preview should have mediaType"
|
||||
|
||||
async def test_jobs_api_pagination(
|
||||
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test jobs API pagination"""
|
||||
for _ in range(5):
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
first_page = await client.get_jobs(limit=2, offset=0)
|
||||
second_page = await client.get_jobs(limit=2, offset=2)
|
||||
|
||||
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
|
||||
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
|
||||
|
||||
first_ids = {j["id"] for j in first_page["jobs"]}
|
||||
second_ids = {j["id"] for j in second_page["jobs"]}
|
||||
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
|
||||
|
||||
async def test_jobs_api_sorting(
|
||||
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test jobs API sorting"""
|
||||
for _ in range(3):
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
desc_jobs = await client.get_jobs(sort_order="desc")
|
||||
asc_jobs = await client.get_jobs(sort_order="asc")
|
||||
|
||||
if len(desc_jobs["jobs"]) >= 2:
|
||||
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
|
||||
asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
|
||||
if len(desc_times) >= 2:
|
||||
assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
|
||||
if len(asc_times) >= 2:
|
||||
assert asc_times == sorted(asc_times), "Asc should be oldest first"
|
||||
|
||||
async def test_jobs_api_status_filter(
|
||||
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test jobs API status filtering"""
|
||||
await self._create_history_item(client, builder)
|
||||
|
||||
completed_jobs = await client.get_jobs(status="completed")
|
||||
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
|
||||
|
||||
for job in completed_jobs["jobs"]:
|
||||
assert job["status"] == "completed", "Should only return completed jobs"
|
||||
|
||||
# Pending jobs are transient - just verify filter doesn't error
|
||||
pending_jobs = await client.get_jobs(status="pending")
|
||||
for job in pending_jobs["jobs"]:
|
||||
assert job["status"] == "pending", "Should only return pending jobs"
|
||||
|
||||
async def test_get_job_by_id(
|
||||
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test getting a single job by ID"""
|
||||
result = await self._create_history_item(client, builder)
|
||||
prompt_id = result.get_prompt_id()
|
||||
|
||||
job = await client.get_job(prompt_id)
|
||||
assert job is not None, "Should find the job"
|
||||
assert job["id"] == prompt_id, "Job ID should match"
|
||||
assert "outputs" in job, "Single job should include outputs"
|
||||
|
||||
async def test_get_job_not_found(
|
||||
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test getting a non-existent job returns 404"""
|
||||
job = await client.get_job("nonexistent-job-id")
|
||||
assert job is None, "Non-existent job should return None"
|
||||
Loading…
Reference in New Issue
Block a user