mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
tests pass
This commit is contained in:
parent
1b29ff09da
commit
71698f4099
@ -414,20 +414,6 @@ paths:
|
|||||||
responses:
|
responses:
|
||||||
200:
|
200:
|
||||||
headers:
|
headers:
|
||||||
Idempotency-Key:
|
|
||||||
description: |
|
|
||||||
The API supports idempotency for safely retrying requests without accidentally performing the same operation twice. When creating or updating an object, use an idempotency key. Then, if a connection error occurs, you can safely repeat the request without risk of creating a second object or performing the update twice.
|
|
||||||
|
|
||||||
To perform an idempotent request, provide an additional IdempotencyKey element to the request options.
|
|
||||||
|
|
||||||
Idempotency works by saving the resulting status code and body of the first request made for any given idempotency key, regardless of whether it succeeds or fails. Subsequent requests with the same key return the same result, including 500 errors.
|
|
||||||
|
|
||||||
A client generates an idempotency key, which is a unique key that the server uses to recognize subsequent retries of the same request. How you create unique keys is up to you, but we suggest using V4 UUIDs, or another random string with enough entropy to avoid collisions. Idempotency keys are up to 255 characters long.
|
|
||||||
|
|
||||||
You can remove keys from the system automatically after they’re at least 24 hours old. We generate a new request if a key is reused after the original is pruned. The idempotency layer compares incoming parameters to those of the original request and errors if they’re the same to prevent accidental misuse.
|
|
||||||
example: XFDSF000213
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
Digest:
|
Digest:
|
||||||
description: The digest of the request body
|
description: The digest of the request body
|
||||||
example: SHA256=e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
|
example: SHA256=e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
|
||||||
@ -439,6 +425,10 @@ paths:
|
|||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
pattern: '^filename=.+'
|
pattern: '^filename=.+'
|
||||||
|
Location:
|
||||||
|
description: The relative URL to revisit this request
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
description: |
|
description: |
|
||||||
The content of the last SaveImage node.
|
The content of the last SaveImage node.
|
||||||
content:
|
content:
|
||||||
@ -467,6 +457,10 @@ paths:
|
|||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
|
prompt_id:
|
||||||
|
type: string
|
||||||
|
description:
|
||||||
|
The ID of the prompt that was queued and executed
|
||||||
outputs:
|
outputs:
|
||||||
$ref: "#/components/schemas/Outputs"
|
$ref: "#/components/schemas/Outputs"
|
||||||
example:
|
example:
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
from asyncio import AbstractEventLoop
|
from asyncio import AbstractEventLoop
|
||||||
@ -65,7 +66,7 @@ class AsyncRemoteComfyClient:
|
|||||||
return headers
|
return headers
|
||||||
|
|
||||||
@tracer.start_as_current_span("Post Prompt")
|
@tracer.start_as_current_span("Post Prompt")
|
||||||
async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
|
async def _post_prompt(self, prompt: PromptDict | dict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
|
||||||
"""
|
"""
|
||||||
Common method to POST a prompt to a given endpoint.
|
Common method to POST a prompt to a given endpoint.
|
||||||
:param prompt: The prompt to send
|
:param prompt: The prompt to send
|
||||||
@ -101,7 +102,7 @@ class AsyncRemoteComfyClient:
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
|
raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
|
||||||
|
|
||||||
async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
|
async def queue_prompt_api(self, prompt: PromptDict | dict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
|
||||||
"""
|
"""
|
||||||
Calls the API to queue a prompt.
|
Calls the API to queue a prompt.
|
||||||
:param prompt:
|
:param prompt:
|
||||||
@ -211,3 +212,50 @@ class AsyncRemoteComfyClient:
|
|||||||
return response.status, None
|
return response.status, None
|
||||||
# Timeout
|
# Timeout
|
||||||
return 408, None
|
return 408, None
|
||||||
|
|
||||||
|
async def get_jobs(self, status: Optional[str] = None, workflow_id: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None, offset: Optional[int] = None,
|
||||||
|
sort_by: Optional[str] = None, sort_order: Optional[str] = None) -> dict:
|
||||||
|
"""
|
||||||
|
List all jobs with filtering, sorting, and pagination.
|
||||||
|
:param status: Filter by status (comma-separated): pending, in_progress, completed, failed
|
||||||
|
:param workflow_id: Filter by workflow ID
|
||||||
|
:param limit: Max items to return
|
||||||
|
:param offset: Items to skip
|
||||||
|
:param sort_by: Sort field: created_at (default), execution_duration
|
||||||
|
:param sort_order: Sort direction: asc, desc (default)
|
||||||
|
:return: Dictionary containing jobs list and pagination info
|
||||||
|
"""
|
||||||
|
params = {}
|
||||||
|
if status is not None:
|
||||||
|
params["status"] = status
|
||||||
|
if workflow_id is not None:
|
||||||
|
params["workflow_id"] = workflow_id
|
||||||
|
if limit is not None:
|
||||||
|
params["limit"] = str(limit)
|
||||||
|
if offset is not None:
|
||||||
|
params["offset"] = str(offset)
|
||||||
|
if sort_by is not None:
|
||||||
|
params["sort_by"] = sort_by
|
||||||
|
if sort_order is not None:
|
||||||
|
params["sort_order"] = sort_order
|
||||||
|
|
||||||
|
async with self.session.get(urljoin(self.server_address, "/api/jobs"), params=params) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"could not get jobs: {response.status}: {await response.text()}")
|
||||||
|
|
||||||
|
async def get_job(self, job_id: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Get a single job by ID.
|
||||||
|
:param job_id: The job ID
|
||||||
|
:return: Job dictionary or None if not found
|
||||||
|
"""
|
||||||
|
async with self.session.get(urljoin(self.server_address, f"/api/jobs/{job_id}")) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
elif response.status == 404:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"could not get job: {response.status}: {await response.text()}")
|
||||||
|
|||||||
@ -24,6 +24,7 @@ class Output(TypedDict, total=False):
|
|||||||
class V1QueuePromptResponse:
|
class V1QueuePromptResponse:
|
||||||
urls: List[str]
|
urls: List[str]
|
||||||
outputs: dict[str, Output]
|
outputs: dict[str, Output]
|
||||||
|
prompt_id: str
|
||||||
|
|
||||||
|
|
||||||
class ProgressNotification(NamedTuple):
|
class ProgressNotification(NamedTuple):
|
||||||
|
|||||||
@ -382,19 +382,21 @@ class Comfy:
|
|||||||
|
|
||||||
async def queue_prompt_api(self,
|
async def queue_prompt_api(self,
|
||||||
prompt: PromptDict | str | dict,
|
prompt: PromptDict | str | dict,
|
||||||
progress_handler: Optional[ExecutorToClientProgress] = None) -> V1QueuePromptResponse:
|
progress_handler: Optional[ExecutorToClientProgress] = None,
|
||||||
|
prompt_id: Optional[str] = None) -> V1QueuePromptResponse:
|
||||||
"""
|
"""
|
||||||
Queues a prompt for execution, returning the output when it is complete.
|
Queues a prompt for execution, returning the output when it is complete.
|
||||||
:param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
|
:param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
|
||||||
:return: a response of URLs for Save-related nodes and the node outputs
|
:return: a response of URLs for Save-related nodes and the node outputs
|
||||||
"""
|
"""
|
||||||
|
prompt_id = prompt_id or str(uuid.uuid4())
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
prompt = json.loads(prompt)
|
prompt = json.loads(prompt)
|
||||||
if isinstance(prompt, dict):
|
if isinstance(prompt, dict):
|
||||||
from ..api.components.schema.prompt import Prompt
|
from ..api.components.schema.prompt import Prompt
|
||||||
prompt = Prompt.validate(prompt)
|
prompt = Prompt.validate(prompt)
|
||||||
outputs = await self.queue_prompt(prompt, progress_handler=progress_handler)
|
outputs = await self.queue_prompt(prompt, progress_handler=progress_handler, prompt_id=prompt_id)
|
||||||
return V1QueuePromptResponse(urls=[], outputs=outputs)
|
return V1QueuePromptResponse(urls=[], outputs=outputs, prompt_id=prompt_id)
|
||||||
|
|
||||||
def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress:
|
def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1193,6 +1193,7 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
filename = main_image["filename"]
|
filename = main_image["filename"]
|
||||||
digest_headers_ = {
|
digest_headers_ = {
|
||||||
"Digest": f"SHA-256={content_digest}",
|
"Digest": f"SHA-256={content_digest}",
|
||||||
|
"Location": f"/api/v1/prompts/{task_id}"
|
||||||
}
|
}
|
||||||
urls_ = []
|
urls_ = []
|
||||||
if len(output_images) == 1:
|
if len(output_images) == 1:
|
||||||
@ -1224,8 +1225,10 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
headers=digest_headers_,
|
headers=digest_headers_,
|
||||||
body=json.dumps({
|
body=json.dumps({
|
||||||
'urls': urls_,
|
'urls': urls_,
|
||||||
'outputs': result.outputs
|
'outputs': result.outputs,
|
||||||
|
'prompt_id': task_id,
|
||||||
}))
|
}))
|
||||||
|
# todo: provide more ways to accept these, or support multiple file returns easily
|
||||||
elif accept == "image/png" or accept == "image/jpeg":
|
elif accept == "image/png" or accept == "image/jpeg":
|
||||||
return web.FileResponse(main_image["abs_path"],
|
return web.FileResponse(main_image["abs_path"],
|
||||||
headers=digest_headers_)
|
headers=digest_headers_)
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from .k_diffusion import sampling as k_diffusion_sampling
|
|||||||
from .model_base import BaseModel
|
from .model_base import BaseModel
|
||||||
from .model_management_types import ModelOptions
|
from .model_management_types import ModelOptions
|
||||||
from .model_patcher import ModelPatcher
|
from .model_patcher import ModelPatcher
|
||||||
|
from .sampler_helpers import prepare_mask
|
||||||
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES, KSAMPLER_NAMES
|
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES, KSAMPLER_NAMES
|
||||||
from .context_windows import ContextHandlerABC
|
from .context_windows import ContextHandlerABC
|
||||||
from .utils import common_upscale, pack_latents, unpack_latents
|
from .utils import common_upscale, pack_latents, unpack_latents
|
||||||
@ -1068,10 +1069,10 @@ class CFGGuider:
|
|||||||
denoise_masks.append(torch.ones(latent_shapes[i]))
|
denoise_masks.append(torch.ones(latent_shapes[i]))
|
||||||
|
|
||||||
for i in range(len(denoise_masks)):
|
for i in range(len(denoise_masks)):
|
||||||
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
|
denoise_masks[i] = prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
|
||||||
|
|
||||||
if len(denoise_masks) > 1:
|
if len(denoise_masks) > 1:
|
||||||
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
denoise_mask, _ = pack_latents(denoise_masks)
|
||||||
else:
|
else:
|
||||||
denoise_mask = denoise_masks[0]
|
denoise_mask = denoise_masks[0]
|
||||||
|
|
||||||
|
|||||||
@ -6,59 +6,68 @@ import torch.nn.functional as F
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from collections import namedtuple, deque
|
from collections import namedtuple, deque
|
||||||
|
|
||||||
import comfy.ops
|
from ..ops import disable_weight_init
|
||||||
operations=comfy.ops.disable_weight_init
|
|
||||||
|
operations = disable_weight_init
|
||||||
|
|
||||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
||||||
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
||||||
|
|
||||||
|
|
||||||
def conv(n_in, n_out, **kwargs):
|
def conv(n_in, n_out, **kwargs):
|
||||||
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Clamp(nn.Module):
|
class Clamp(nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.tanh(x / 3) * 3
|
return torch.tanh(x / 3) * 3
|
||||||
|
|
||||||
|
|
||||||
class MemBlock(nn.Module):
|
class MemBlock(nn.Module):
|
||||||
def __init__(self, n_in, n_out, act_func):
|
def __init__(self, n_in, n_out, act_func):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
|
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
|
||||||
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||||
self.act = act_func
|
self.act = act_func
|
||||||
|
|
||||||
def forward(self, x, past):
|
def forward(self, x, past):
|
||||||
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
||||||
|
|
||||||
|
|
||||||
class TPool(nn.Module):
|
class TPool(nn.Module):
|
||||||
def __init__(self, n_f, stride):
|
def __init__(self, n_f, stride):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
|
self.conv = operations.Conv2d(n_f * stride, n_f, 1, bias=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
_NT, C, H, W = x.shape
|
_NT, C, H, W = x.shape
|
||||||
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
||||||
|
|
||||||
|
|
||||||
class TGrow(nn.Module):
|
class TGrow(nn.Module):
|
||||||
def __init__(self, n_f, stride):
|
def __init__(self, n_f, stride):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
|
self.conv = operations.Conv2d(n_f, n_f * stride, 1, bias=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
_NT, C, H, W = x.shape
|
_NT, C, H, W = x.shape
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
return x.reshape(-1, C, H, W)
|
return x.reshape(-1, C, H, W)
|
||||||
|
|
||||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
|
||||||
|
|
||||||
|
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||||
B, T, C, H, W = x.shape
|
B, T, C, H, W = x.shape
|
||||||
if parallel:
|
if parallel:
|
||||||
x = x.reshape(B*T, C, H, W)
|
x = x.reshape(B * T, C, H, W)
|
||||||
# parallel over input timesteps, iterate over blocks
|
# parallel over input timesteps, iterate over blocks
|
||||||
for b in tqdm(model, disable=not show_progress_bar):
|
for b in tqdm(model, disable=not show_progress_bar):
|
||||||
if isinstance(b, MemBlock):
|
if isinstance(b, MemBlock):
|
||||||
BT, C, H, W = x.shape
|
BT, C, H, W = x.shape
|
||||||
T = BT // B
|
T = BT // B
|
||||||
_x = x.reshape(B, T, C, H, W)
|
_x = x.reshape(B, T, C, H, W)
|
||||||
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
|
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
|
||||||
x = b(x, mem)
|
x = b(x, mem)
|
||||||
else:
|
else:
|
||||||
x = b(x)
|
x = b(x)
|
||||||
@ -87,25 +96,25 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
|||||||
xt_new = b(xt, mem[i])
|
xt_new = b(xt, mem[i])
|
||||||
mem[i] = xt.detach().clone()
|
mem[i] = xt.detach().clone()
|
||||||
del xt
|
del xt
|
||||||
work_queue.appendleft(TWorkItem(xt_new, i+1))
|
work_queue.appendleft(TWorkItem(xt_new, i + 1))
|
||||||
elif isinstance(b, TPool):
|
elif isinstance(b, TPool):
|
||||||
if mem[i] is None:
|
if mem[i] is None:
|
||||||
mem[i] = []
|
mem[i] = []
|
||||||
mem[i].append(xt.detach().clone())
|
mem[i].append(xt.detach().clone())
|
||||||
if len(mem[i]) == b.stride:
|
if len(mem[i]) == b.stride:
|
||||||
B, C, H, W = xt.shape
|
B, C, H, W = xt.shape
|
||||||
xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
|
xt = b(torch.cat(mem[i], 1).view(B * b.stride, C, H, W))
|
||||||
mem[i] = []
|
mem[i] = []
|
||||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
work_queue.appendleft(TWorkItem(xt, i + 1))
|
||||||
elif isinstance(b, TGrow):
|
elif isinstance(b, TGrow):
|
||||||
xt = b(xt)
|
xt = b(xt)
|
||||||
NT, C, H, W = xt.shape
|
NT, C, H, W = xt.shape
|
||||||
for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
|
for xt_next in reversed(xt.view(B, b.stride * C, H, W).chunk(b.stride, 1)):
|
||||||
work_queue.appendleft(TWorkItem(xt_next, i+1))
|
work_queue.appendleft(TWorkItem(xt_next, i + 1))
|
||||||
del xt
|
del xt
|
||||||
else:
|
else:
|
||||||
xt = b(xt)
|
xt = b(xt)
|
||||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
work_queue.appendleft(TWorkItem(xt, i + 1))
|
||||||
progress_bar.close()
|
progress_bar.close()
|
||||||
x = torch.stack(out, 1)
|
x = torch.stack(out, 1)
|
||||||
return x
|
return x
|
||||||
@ -122,29 +131,30 @@ class TAEHV(nn.Module):
|
|||||||
self.show_progress_bar = show_progress_bar
|
self.show_progress_bar = show_progress_bar
|
||||||
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
|
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
|
||||||
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
|
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
|
||||||
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
||||||
self.patch_size = 2
|
self.patch_size = 2
|
||||||
if self.latent_channels == 32: # HunyuanVideo1.5
|
if self.latent_channels == 32: # HunyuanVideo1.5
|
||||||
act_func = nn.LeakyReLU(0.2, inplace=True)
|
act_func = nn.LeakyReLU(0.2, inplace=True)
|
||||||
else: # HunyuanVideo, Wan 2.1
|
else: # HunyuanVideo, Wan 2.1
|
||||||
act_func = nn.ReLU(inplace=True)
|
act_func = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
self.encoder = nn.Sequential(
|
self.encoder = nn.Sequential(
|
||||||
conv(self.image_channels*self.patch_size**2, 64), act_func,
|
conv(self.image_channels * self.patch_size ** 2, 64), act_func,
|
||||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||||
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||||
conv(64, self.latent_channels),
|
conv(64, self.latent_channels),
|
||||||
)
|
)
|
||||||
n_f = [256, 128, 64, 64]
|
n_f = [256, 128, 64, 64]
|
||||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
|
||||||
self.decoder = nn.Sequential(
|
self.decoder = nn.Sequential(
|
||||||
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
|
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
|
||||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
||||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||||
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
|
act_func, conv(n_f[3], self.image_channels * self.patch_size ** 2),
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def show_progress_bar(self):
|
def show_progress_bar(self):
|
||||||
return self._show_progress_bar
|
return self._show_progress_bar
|
||||||
|
|||||||
@ -8,13 +8,15 @@ import torch
|
|||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
import comfy.model_management
|
|
||||||
import comfy.ops
|
|
||||||
from comfy import sd1_clip
|
|
||||||
from .spiece_tokenizer import SPieceTokenizer
|
from .spiece_tokenizer import SPieceTokenizer
|
||||||
|
from .. import sd1_clip
|
||||||
|
from ..ldm.modules.attention import optimized_attention_for_device
|
||||||
|
|
||||||
|
|
||||||
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
|
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||||
|
if tokenizer_data is None:
|
||||||
|
tokenizer_data = {}
|
||||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
|
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
|
||||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
|
||||||
@ -22,10 +24,14 @@ class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
|
|||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||||
|
|
||||||
|
|
||||||
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
|
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||||
|
if tokenizer_data is None:
|
||||||
|
tokenizer_data = {}
|
||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
|
||||||
|
|
||||||
|
|
||||||
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
|
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
|
||||||
@dataclass
|
@dataclass
|
||||||
class XLMRobertaConfig:
|
class XLMRobertaConfig:
|
||||||
@ -44,6 +50,7 @@ class XLMRobertaConfig:
|
|||||||
eos_token_id: int = 2
|
eos_token_id: int = 2
|
||||||
pad_token_id: int = 1
|
pad_token_id: int = 1
|
||||||
|
|
||||||
|
|
||||||
class XLMRobertaEmbeddings(nn.Module):
|
class XLMRobertaEmbeddings(nn.Module):
|
||||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -61,6 +68,7 @@ class XLMRobertaEmbeddings(nn.Module):
|
|||||||
embeddings = embeddings + token_type_embeddings
|
embeddings = embeddings + token_type_embeddings
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim, base, device=None):
|
def __init__(self, dim, base, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -95,6 +103,7 @@ class RotaryEmbedding(nn.Module):
|
|||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
class MHA(nn.Module):
|
class MHA(nn.Module):
|
||||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -122,6 +131,7 @@ class MHA(nn.Module):
|
|||||||
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
|
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
|
||||||
return self.out_proj(out)
|
return self.out_proj(out)
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -135,6 +145,7 @@ class MLP(nn.Module):
|
|||||||
x = self.fc2(x)
|
x = self.fc2(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -152,17 +163,19 @@ class Block(nn.Module):
|
|||||||
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
|
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class XLMRobertaEncoder(nn.Module):
|
class XLMRobertaEncoder(nn.Module):
|
||||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None):
|
def forward(self, hidden_states, attention_mask=None):
|
||||||
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
|
optimized_attention = optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
|
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class XLMRobertaModel_(nn.Module):
|
class XLMRobertaModel_(nn.Module):
|
||||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -171,7 +184,9 @@ class XLMRobertaModel_(nn.Module):
|
|||||||
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
||||||
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
|
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=None):
|
||||||
|
if embeds_info is None:
|
||||||
|
embeds_info = []
|
||||||
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
|
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
|
||||||
x = self.emb_ln(x)
|
x = self.emb_ln(x)
|
||||||
x = self.emb_drop(x)
|
x = self.emb_drop(x)
|
||||||
@ -194,6 +209,7 @@ class XLMRobertaModel_(nn.Module):
|
|||||||
# Intermediate output is not yet implemented, use None for placeholder
|
# Intermediate output is not yet implemented, use None for placeholder
|
||||||
return sequence_output, None, pooled_output
|
return sequence_output, None, pooled_output
|
||||||
|
|
||||||
|
|
||||||
class XLMRobertaModel(nn.Module):
|
class XLMRobertaModel(nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -210,10 +226,17 @@ class XLMRobertaModel(nn.Module):
|
|||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return self.model(*args, **kwargs)
|
return self.model(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class JinaClip2TextModel(sd1_clip.SDClipModel):
|
class JinaClip2TextModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
|
||||||
|
if model_options is None:
|
||||||
|
model_options = {}
|
||||||
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
|
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||||
|
if model_options is None:
|
||||||
|
model_options = {}
|
||||||
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
|
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,18 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import comfy.model_management
|
from ..model_management import pick_weight_dtype
|
||||||
import comfy.text_encoders.jina_clip_2
|
from .jina_clip_2 import JinaClip2TextModel, JinaClip2Tokenizer
|
||||||
import comfy.text_encoders.lumina2
|
from .lumina2 import Gemma3_4BTokenizer, Gemma3_4BModel
|
||||||
|
|
||||||
|
|
||||||
class NewBieTokenizer:
|
class NewBieTokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||||
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
|
if tokenizer_data is None:
|
||||||
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
|
tokenizer_data = {}
|
||||||
|
self.gemma = Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
|
||||||
|
self.jina = JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
|
||||||
|
|
||||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
|
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
|
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
@ -21,12 +24,15 @@ class NewBieTokenizer:
|
|||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class NewBieTEModel(torch.nn.Module):
|
class NewBieTEModel(torch.nn.Module):
|
||||||
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
|
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
|
if model_options is None:
|
||||||
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
|
model_options = {}
|
||||||
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
|
dtype_gemma = pick_weight_dtype(dtype_gemma, dtype, device)
|
||||||
|
self.gemma = Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
|
||||||
|
self.jina = JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
|
||||||
self.dtypes = {dtype, dtype_gemma}
|
self.dtypes = {dtype, dtype_gemma}
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
@ -52,11 +58,15 @@ class NewBieTEModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.jina.load_sd(sd)
|
return self.jina.load_sd(sd)
|
||||||
|
|
||||||
|
|
||||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
class NewBieTEModel_(NewBieTEModel):
|
class NewBieTEModel_(NewBieTEModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||||
|
if model_options is None:
|
||||||
|
model_options = {}
|
||||||
if llama_quantization_metadata is not None:
|
if llama_quantization_metadata is not None:
|
||||||
model_options = model_options.copy()
|
model_options = model_options.copy()
|
||||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||||
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||||
|
|
||||||
return NewBieTEModel_
|
return NewBieTEModel_
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import comfy.utils
|
from ..utils import detect_layer_quantization
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from ..transformers_compat import T5TokenizerFast
|
from ..transformers_compat import T5TokenizerFast
|
||||||
@ -33,7 +33,7 @@ def t5_xxl_detect(state_dict, prefix=""):
|
|||||||
if t5_key in state_dict:
|
if t5_key in state_dict:
|
||||||
out["dtype_t5"] = state_dict[t5_key].dtype
|
out["dtype_t5"] = state_dict[t5_key].dtype
|
||||||
|
|
||||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
quant = detect_layer_quantization(state_dict, prefix)
|
||||||
if quant is not None:
|
if quant is not None:
|
||||||
out["t5_quantization_metadata"] = quant
|
out["t5_quantization_metadata"] = quant
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,8 @@ from typing_extensions import override
|
|||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
from comfy.nodes import base_nodes as nodes
|
||||||
|
|
||||||
|
|
||||||
class TextEncodeQwenImageEdit(io.ComfyNode):
|
class TextEncodeQwenImageEdit(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -845,13 +845,13 @@ class TestExecution:
|
|||||||
assert numpy.array(result_image).mean() == 0, "Image should be black"
|
assert numpy.array(result_image).mean() == 0, "Image should be black"
|
||||||
|
|
||||||
# Jobs API tests
|
# Jobs API tests
|
||||||
def test_jobs_api_job_structure(
|
async def test_jobs_api_job_structure(
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
self, client: ComfyClient, builder: GraphBuilder
|
||||||
):
|
):
|
||||||
"""Test that job objects have required fields"""
|
"""Test that job objects have required fields"""
|
||||||
self._create_history_item(client, builder)
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
jobs_response = client.get_jobs(status="completed", limit=1)
|
jobs_response = await client.get_jobs(status="completed", limit=1)
|
||||||
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
|
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
|
||||||
|
|
||||||
job = jobs_response["jobs"][0]
|
job = jobs_response["jobs"][0]
|
||||||
@ -861,13 +861,13 @@ class TestExecution:
|
|||||||
assert "outputs_count" in job, "Job should have outputs_count"
|
assert "outputs_count" in job, "Job should have outputs_count"
|
||||||
assert "preview_output" in job, "Job should have preview_output"
|
assert "preview_output" in job, "Job should have preview_output"
|
||||||
|
|
||||||
def test_jobs_api_preview_output_structure(
|
async def test_jobs_api_preview_output_structure(
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
self, client: ComfyClient, builder: GraphBuilder
|
||||||
):
|
):
|
||||||
"""Test that preview_output has correct structure"""
|
"""Test that preview_output has correct structure"""
|
||||||
self._create_history_item(client, builder)
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
jobs_response = client.get_jobs(status="completed", limit=1)
|
jobs_response = await client.get_jobs(status="completed", limit=1)
|
||||||
job = jobs_response["jobs"][0]
|
job = jobs_response["jobs"][0]
|
||||||
|
|
||||||
if job["preview_output"] is not None:
|
if job["preview_output"] is not None:
|
||||||
@ -876,15 +876,15 @@ class TestExecution:
|
|||||||
assert "nodeId" in preview, "Preview should have nodeId"
|
assert "nodeId" in preview, "Preview should have nodeId"
|
||||||
assert "mediaType" in preview, "Preview should have mediaType"
|
assert "mediaType" in preview, "Preview should have mediaType"
|
||||||
|
|
||||||
def test_jobs_api_pagination(
|
async def test_jobs_api_pagination(
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
self, client: ComfyClient, builder: GraphBuilder
|
||||||
):
|
):
|
||||||
"""Test jobs API pagination"""
|
"""Test jobs API pagination"""
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
self._create_history_item(client, builder)
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
first_page = client.get_jobs(limit=2, offset=0)
|
first_page = await client.get_jobs(limit=2, offset=0)
|
||||||
second_page = client.get_jobs(limit=2, offset=2)
|
second_page = await client.get_jobs(limit=2, offset=2)
|
||||||
|
|
||||||
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
|
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
|
||||||
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
|
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
|
||||||
@ -893,15 +893,15 @@ class TestExecution:
|
|||||||
second_ids = {j["id"] for j in second_page["jobs"]}
|
second_ids = {j["id"] for j in second_page["jobs"]}
|
||||||
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
|
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
|
||||||
|
|
||||||
def test_jobs_api_sorting(
|
async def test_jobs_api_sorting(
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
self, client: ComfyClient, builder: GraphBuilder
|
||||||
):
|
):
|
||||||
"""Test jobs API sorting"""
|
"""Test jobs API sorting"""
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
self._create_history_item(client, builder)
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
desc_jobs = client.get_jobs(sort_order="desc")
|
desc_jobs = await client.get_jobs(sort_order="desc")
|
||||||
asc_jobs = client.get_jobs(sort_order="asc")
|
asc_jobs = await client.get_jobs(sort_order="asc")
|
||||||
|
|
||||||
if len(desc_jobs["jobs"]) >= 2:
|
if len(desc_jobs["jobs"]) >= 2:
|
||||||
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
|
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
|
||||||
@ -911,31 +911,31 @@ class TestExecution:
|
|||||||
if len(asc_times) >= 2:
|
if len(asc_times) >= 2:
|
||||||
assert asc_times == sorted(asc_times), "Asc should be oldest first"
|
assert asc_times == sorted(asc_times), "Asc should be oldest first"
|
||||||
|
|
||||||
def test_jobs_api_status_filter(
|
async def test_jobs_api_status_filter(
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
self, client: ComfyClient, builder: GraphBuilder
|
||||||
):
|
):
|
||||||
"""Test jobs API status filtering"""
|
"""Test jobs API status filtering"""
|
||||||
self._create_history_item(client, builder)
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
completed_jobs = client.get_jobs(status="completed")
|
completed_jobs = await client.get_jobs(status="completed")
|
||||||
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
|
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
|
||||||
|
|
||||||
for job in completed_jobs["jobs"]:
|
for job in completed_jobs["jobs"]:
|
||||||
assert job["status"] == "completed", "Should only return completed jobs"
|
assert job["status"] == "completed", "Should only return completed jobs"
|
||||||
|
|
||||||
# Pending jobs are transient - just verify filter doesn't error
|
# Pending jobs are transient - just verify filter doesn't error
|
||||||
pending_jobs = client.get_jobs(status="pending")
|
pending_jobs = await client.get_jobs(status="pending")
|
||||||
for job in pending_jobs["jobs"]:
|
for job in pending_jobs["jobs"]:
|
||||||
assert job["status"] == "pending", "Should only return pending jobs"
|
assert job["status"] == "pending", "Should only return pending jobs"
|
||||||
|
|
||||||
def test_get_job_by_id(
|
async def test_get_job_by_id(
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
self, client: ComfyClient, builder: GraphBuilder
|
||||||
):
|
):
|
||||||
"""Test getting a single job by ID"""
|
"""Test getting a single job by ID"""
|
||||||
result = self._create_history_item(client, builder)
|
result = await self._create_history_item(client, builder)
|
||||||
prompt_id = result.get_prompt_id()
|
prompt_id = result.get_prompt_id()
|
||||||
|
|
||||||
job = client.get_job(prompt_id)
|
job = await client.get_job(prompt_id)
|
||||||
assert job is not None, "Should find the job"
|
assert job is not None, "Should find the job"
|
||||||
assert job["id"] == prompt_id, "Job ID should match"
|
assert job["id"] == prompt_id, "Job ID should match"
|
||||||
assert "outputs" in job, "Single job should include outputs"
|
assert "outputs" in job, "Single job should include outputs"
|
||||||
|
|||||||
135
tests/execution/test_jobs_from_execution.py
Normal file
135
tests/execution/test_jobs_from_execution.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
# Jobs API tests
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from comfy.client.aio_client import AsyncRemoteComfyClient
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Result:
|
||||||
|
res:dict
|
||||||
|
|
||||||
|
def get_prompt_id(self):
|
||||||
|
return self.res["prompt_id"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestJobs:
|
||||||
|
@pytest.fixture
|
||||||
|
def builder(self, request):
|
||||||
|
yield GraphBuilder(prefix=request.node.name)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(self, comfy_background_server_from_config):
|
||||||
|
async with AsyncRemoteComfyClient(f"http://localhost:{comfy_background_server_from_config[0].port}") as obj:
|
||||||
|
yield obj
|
||||||
|
|
||||||
|
async def _create_history_item(self, client, builder):
|
||||||
|
g = GraphBuilder(prefix="offset_test")
|
||||||
|
input_node = g.node(
|
||||||
|
"StubImage", content="BLACK", height=32, width=32, batch_size=1
|
||||||
|
)
|
||||||
|
g.node("SaveImage", images=input_node.out(0))
|
||||||
|
await client.queue_prompt_api(g.finalize())
|
||||||
|
|
||||||
|
async def test_jobs_api_job_structure(
|
||||||
|
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||||
|
):
|
||||||
|
"""Test that job objects have required fields"""
|
||||||
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
|
jobs_response = await client.get_jobs(status="completed", limit=1)
|
||||||
|
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
|
||||||
|
|
||||||
|
job = jobs_response["jobs"][0]
|
||||||
|
assert "id" in job, "Job should have id"
|
||||||
|
assert "status" in job, "Job should have status"
|
||||||
|
assert "create_time" in job, "Job should have create_time"
|
||||||
|
assert "outputs_count" in job, "Job should have outputs_count"
|
||||||
|
assert "preview_output" in job, "Job should have preview_output"
|
||||||
|
|
||||||
|
async def test_jobs_api_preview_output_structure(
|
||||||
|
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||||
|
):
|
||||||
|
"""Test that preview_output has correct structure"""
|
||||||
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
|
jobs_response = await client.get_jobs(status="completed", limit=1)
|
||||||
|
job = jobs_response["jobs"][0]
|
||||||
|
|
||||||
|
if job["preview_output"] is not None:
|
||||||
|
preview = job["preview_output"]
|
||||||
|
assert "filename" in preview, "Preview should have filename"
|
||||||
|
assert "nodeId" in preview, "Preview should have nodeId"
|
||||||
|
assert "mediaType" in preview, "Preview should have mediaType"
|
||||||
|
|
||||||
|
async def test_jobs_api_pagination(
|
||||||
|
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||||
|
):
|
||||||
|
"""Test jobs API pagination"""
|
||||||
|
for _ in range(5):
|
||||||
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
|
first_page = await client.get_jobs(limit=2, offset=0)
|
||||||
|
second_page = await client.get_jobs(limit=2, offset=2)
|
||||||
|
|
||||||
|
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
|
||||||
|
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
|
||||||
|
|
||||||
|
first_ids = {j["id"] for j in first_page["jobs"]}
|
||||||
|
second_ids = {j["id"] for j in second_page["jobs"]}
|
||||||
|
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
|
||||||
|
|
||||||
|
async def test_jobs_api_sorting(
|
||||||
|
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||||
|
):
|
||||||
|
"""Test jobs API sorting"""
|
||||||
|
for _ in range(3):
|
||||||
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
|
desc_jobs = await client.get_jobs(sort_order="desc")
|
||||||
|
asc_jobs = await client.get_jobs(sort_order="asc")
|
||||||
|
|
||||||
|
if len(desc_jobs["jobs"]) >= 2:
|
||||||
|
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
|
||||||
|
asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
|
||||||
|
if len(desc_times) >= 2:
|
||||||
|
assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
|
||||||
|
if len(asc_times) >= 2:
|
||||||
|
assert asc_times == sorted(asc_times), "Asc should be oldest first"
|
||||||
|
|
||||||
|
async def test_jobs_api_status_filter(
|
||||||
|
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||||
|
):
|
||||||
|
"""Test jobs API status filtering"""
|
||||||
|
await self._create_history_item(client, builder)
|
||||||
|
|
||||||
|
completed_jobs = await client.get_jobs(status="completed")
|
||||||
|
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
|
||||||
|
|
||||||
|
for job in completed_jobs["jobs"]:
|
||||||
|
assert job["status"] == "completed", "Should only return completed jobs"
|
||||||
|
|
||||||
|
# Pending jobs are transient - just verify filter doesn't error
|
||||||
|
pending_jobs = await client.get_jobs(status="pending")
|
||||||
|
for job in pending_jobs["jobs"]:
|
||||||
|
assert job["status"] == "pending", "Should only return pending jobs"
|
||||||
|
|
||||||
|
async def test_get_job_by_id(
|
||||||
|
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||||
|
):
|
||||||
|
"""Test getting a single job by ID"""
|
||||||
|
result = await self._create_history_item(client, builder)
|
||||||
|
prompt_id = result.get_prompt_id()
|
||||||
|
|
||||||
|
job = await client.get_job(prompt_id)
|
||||||
|
assert job is not None, "Should find the job"
|
||||||
|
assert job["id"] == prompt_id, "Job ID should match"
|
||||||
|
assert "outputs" in job, "Single job should include outputs"
|
||||||
|
|
||||||
|
async def test_get_job_not_found(
|
||||||
|
self, client: AsyncRemoteComfyClient, builder: GraphBuilder
|
||||||
|
):
|
||||||
|
"""Test getting a non-existent job returns 404"""
|
||||||
|
job = await client.get_job("nonexistent-job-id")
|
||||||
|
assert job is None, "Non-existent job should return None"
|
||||||
Loading…
Reference in New Issue
Block a user