mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 09:27:24 +08:00
Compare commits
9 Commits
7a675aae2d
...
4cff963cc9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4cff963cc9 | ||
|
|
c33d26c283 | ||
|
|
f3ea976cba | ||
|
|
d6e94b7dfc | ||
|
|
65c5dba4e3 | ||
|
|
2274a5d3d3 | ||
|
|
74547bf49b | ||
|
|
39e5f74129 | ||
|
|
187e9f03a9 |
12
api_server/utils/query_params.py
Normal file
12
api_server/utils/query_params.py
Normal file
@ -0,0 +1,12 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
def parse_optional_int_query_param(query: Mapping[str, str], name: str) -> int | None:
|
||||
value = query.get(name)
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"{name} must be an integer") from exc
|
||||
@ -199,6 +199,9 @@ class FILMNet(nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
|
||||
|
||||
def memory_used_forward(self, shape, dtype):
|
||||
return 1700 * shape[1] * shape[2] * dtype.itemsize
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
"""Pre-compute warp grids for all pyramid levels."""
|
||||
if (H, W) in self._warp_grids:
|
||||
|
||||
@ -74,6 +74,9 @@ class IFNet(nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.encode.cnn0.weight.dtype
|
||||
|
||||
def memory_used_forward(self, shape, dtype):
|
||||
return 300 * shape[1] * shape[2] * dtype.itemsize
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
if (H, W) in self._warp_grids:
|
||||
return
|
||||
|
||||
@ -37,7 +37,7 @@ class FrameInterpolationModelLoader(io.ComfyNode):
|
||||
model = cls._detect_and_load(sd)
|
||||
dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
|
||||
model.eval().to(dtype)
|
||||
patcher = comfy.model_patcher.ModelPatcher(
|
||||
patcher = comfy.model_patcher.CoreModelPatcher(
|
||||
model,
|
||||
load_device=model_management.get_torch_device(),
|
||||
offload_device=model_management.unet_offload_device(),
|
||||
@ -98,16 +98,13 @@ class FrameInterpolate(io.ComfyNode):
|
||||
if num_frames < 2 or multiplier < 2:
|
||||
return io.NodeOutput(images)
|
||||
|
||||
model_management.load_model_gpu(interp_model)
|
||||
device = interp_model.load_device
|
||||
dtype = interp_model.model_dtype()
|
||||
inference_model = interp_model.model
|
||||
|
||||
# Free VRAM for inference activations (model weights + ~20x a single frame's worth)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
activation_mem = H * W * 3 * images.element_size() * 20
|
||||
model_management.free_memory(activation_mem, device)
|
||||
activation_mem = inference_model.memory_used_forward(images.shape, dtype)
|
||||
model_management.load_models_gpu([interp_model], memory_required=activation_mem)
|
||||
align = getattr(inference_model, "pad_align", 1)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
|
||||
# Prepare a single padded frame on device for determining output dimensions
|
||||
def prepare_frame(idx):
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
#config for a1111 ui
|
||||
#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
|
||||
|
||||
#a111:
|
||||
#a1111:
|
||||
# base_path: path/to/stable-diffusion-webui/
|
||||
# checkpoints: models/Stable-diffusion
|
||||
# configs: models/Stable-diffusion
|
||||
|
||||
16
server.py
16
server.py
@ -46,6 +46,7 @@ from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from api_server.utils.query_params import parse_optional_int_query_param
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
# Import cache control middleware
|
||||
@ -888,14 +889,15 @@ class PromptServer():
|
||||
|
||||
@routes.get("/history")
|
||||
async def get_history(request):
|
||||
max_items = request.rel_url.query.get("max_items", None)
|
||||
if max_items is not None:
|
||||
max_items = int(max_items)
|
||||
query = request.rel_url.query
|
||||
|
||||
offset = request.rel_url.query.get("offset", None)
|
||||
if offset is not None:
|
||||
offset = int(offset)
|
||||
else:
|
||||
try:
|
||||
max_items = parse_optional_int_query_param(query, "max_items")
|
||||
offset = parse_optional_int_query_param(query, "offset")
|
||||
except ValueError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
|
||||
if offset is None:
|
||||
offset = -1
|
||||
|
||||
return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset))
|
||||
|
||||
39
tests-unit/server/utils/query_params_test.py
Normal file
39
tests-unit/server/utils/query_params_test.py
Normal file
@ -0,0 +1,39 @@
|
||||
import pytest
|
||||
|
||||
from api_server.utils.query_params import parse_optional_int_query_param
|
||||
|
||||
|
||||
def test_parse_optional_int_query_param_returns_none_when_missing():
|
||||
assert parse_optional_int_query_param({}, "offset") is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw_value", "expected"),
|
||||
[
|
||||
("0", 0),
|
||||
("5", 5),
|
||||
("-1", -1),
|
||||
],
|
||||
)
|
||||
def test_parse_optional_int_query_param_parses_integers(raw_value, expected):
|
||||
query = {"offset": raw_value}
|
||||
|
||||
assert parse_optional_int_query_param(query, "offset") == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("name", "raw_value"),
|
||||
[
|
||||
("offset", "not-an-integer"),
|
||||
("offset", "1.5"),
|
||||
("offset", ""),
|
||||
("max_items", "not-an-integer"),
|
||||
],
|
||||
)
|
||||
def test_parse_optional_int_query_param_rejects_invalid_integers(name, raw_value):
|
||||
query = {name: raw_value}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parse_optional_int_query_param(query, name)
|
||||
|
||||
assert str(exc_info.value) == f"{name} must be an integer"
|
||||
@ -909,6 +909,20 @@ class TestExecution:
|
||||
|
||||
assert len(result) <= 1, "Should return at most 1 item when offset is near end"
|
||||
|
||||
def test_history_api_rejects_non_integer_max_items(self, client: ComfyClient):
|
||||
with pytest.raises(urllib.error.HTTPError) as exc_info:
|
||||
client.get_all_history(max_items="not-an-integer")
|
||||
|
||||
assert exc_info.value.code == 400
|
||||
assert json.loads(exc_info.value.read()) == {"error": "max_items must be an integer"}
|
||||
|
||||
def test_history_api_rejects_non_integer_offset(self, client: ComfyClient):
|
||||
with pytest.raises(urllib.error.HTTPError) as exc_info:
|
||||
client.get_all_history(offset="not-an-integer")
|
||||
|
||||
assert exc_info.value.code == 400
|
||||
assert json.loads(exc_info.value.read()) == {"error": "offset must be an integer"}
|
||||
|
||||
# Jobs API tests
|
||||
def test_jobs_api_job_structure(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
|
||||
Loading…
Reference in New Issue
Block a user