Merge branch 'comfyanonymous:master' into master
commit 4a1e3ee925
@@ -1,7 +1,6 @@
 from aiohttp import web
 from typing import Optional
-from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
-from api_server.services.file_service import FileService
+from folder_paths import folder_names_and_paths
 from api_server.services.terminal_service import TerminalService
 import app.logger
 
@@ -15,26 +14,10 @@ class InternalRoutes:
     def __init__(self, prompt_server):
         self.routes: web.RouteTableDef = web.RouteTableDef()
         self._app: Optional[web.Application] = None
-        self.file_service = FileService({
-            "models": models_dir,
-            "user": user_directory,
-            "output": output_directory
-        })
         self.prompt_server = prompt_server
         self.terminal_service = TerminalService(prompt_server)
 
     def setup_routes(self):
-        @self.routes.get('/files')
-        async def list_files(request):
-            directory_key = request.query.get('directory', '')
-            try:
-                file_list = self.file_service.list_files(directory_key)
-                return web.json_response({"files": file_list})
-            except ValueError as e:
-                return web.json_response({"error": str(e)}, status=400)
-            except Exception as e:
-                return web.json_response({"error": str(e)}, status=500)
-
         @self.routes.get('/logs')
         async def get_logs(request):
             return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))

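For orientation, the only route left after this hunk is /logs, which concatenates the buffered log entries into a single string. A small sketch of the body it builds, assuming app.logger.get_logs() yields dicts with "t" (timestamp) and "m" (message, presumably keeping its trailing newline, since the join uses no separator); the sample entries below are invented:

# Sketch only: sample entries are made up; the keys "t" and "m" come from the handler above.
logs = [
    {"t": "2025-02-21 10:00:00", "m": "Starting server\n"},
    {"t": "2025-02-21 10:00:01", "m": "To see the GUI go to: http://127.0.0.1:8188\n"},
]
body = "".join(l["t"] + " - " + l["m"] for l in logs)
# body == "2025-02-21 10:00:00 - Starting server\n2025-02-21 10:00:01 - To see the GUI go to: http://127.0.0.1:8188\n"
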
@@ -1,13 +0,0 @@
-from typing import Dict, List, Optional
-from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
-
-class FileService:
-    def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
-        self.allowed_directories: Dict[str, str] = allowed_directories
-        self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
-
-    def list_files(self, directory_key: str) -> List[FileSystemItem]:
-        if directory_key not in self.allowed_directories:
-            raise ValueError("Invalid directory key")
-        directory_path: str = self.allowed_directories[directory_key]
-        return self.file_system_ops.walk_directory(directory_path)

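For context, the deleted module was a small allow-list wrapper: list_files only walks a directory whose key was registered at construction and raises ValueError otherwise, delegating the actual walk to FileSystemOperations. A minimal usage sketch (only meaningful against the pre-commit tree; the path and the mocked filesystem below are made up):

# Sketch only: FileService was removed by this commit; the mock keeps the example off the real filesystem.
from unittest.mock import MagicMock

from api_server.services.file_service import FileService  # module deleted in this commit

ops = MagicMock()
ops.walk_directory.return_value = [{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100}]

service = FileService({"models": "/path/to/models"}, file_system_ops=ops)
print(service.list_files("models"))   # delegates to ops.walk_directory("/path/to/models")

try:
    service.list_files("invalid")     # key not in allowed_directories
except ValueError as e:
    print(e)                          # Invalid directory key
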
@@ -352,25 +352,6 @@ class FinalLayer(nn.Module):
         return x
 
 
-class RopeEmbedder:
-    def __init__(
-        self, theta: float = 10000.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512)
-    ):
-        super().__init__()
-        self.theta = theta
-        self.axes_dims = axes_dims
-        self.axes_lens = axes_lens
-        self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
-
-    def __call__(self, ids: torch.Tensor):
-        self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
-        result = []
-        for i in range(len(self.axes_dims)):
-            index = ids[:, :, i:i+1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64)
-            result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
-        return torch.cat(result, dim=-1)
-
-
 class NextDiT(nn.Module):
     """
     Diffusion model with a Transformer backbone.

@@ -481,7 +462,6 @@ class NextDiT(nn.Module):
         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
-        # self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
         self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
         self.dim = dim
         self.n_heads = n_heads

@@ -609,7 +589,6 @@ class NextDiT(nn.Module):
 
         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
 
-
     # def forward(self, x, t, cap_feats, cap_mask):
     def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
         t = 1.0 - timesteps

@@ -638,37 +617,3 @@
 
         return -x
 
-    @staticmethod
-    def precompute_freqs_cis(
-        dim: List[int],
-        end: List[int],
-        theta: float = 10000.0,
-    ):
-        """
-        Precompute the frequency tensor for complex exponentials (cis) with
-        given dimensions.
-
-        This function calculates a frequency tensor with complex exponentials
-        using the given dimension 'dim' and the end index 'end'. The 'theta'
-        parameter scales the frequencies. The returned tensor contains complex
-        values in complex64 data type.
-
-        Args:
-            dim (list): Dimension of the frequency tensor.
-            end (list): End index for precomputing frequencies.
-            theta (float, optional): Scaling factor for frequency computation.
-                Defaults to 10000.0.
-
-        Returns:
-            torch.Tensor: Precomputed frequency tensor with complex
-                exponentials.
-        """
-        freqs_cis = []
-        for i, (d, e) in enumerate(zip(dim, end)):
-            freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
-            timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
-            freqs = torch.outer(timestep, freqs).float()
-            freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
-            freqs_cis.append(freqs_cis_i)
-
-        return freqs_cis

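Taken together, the removals in this file drop a table-based rotary embedder: precompute_freqs_cis built one complex table per axis (frequencies 1 / theta**(2k/d) multiplied by every position and wrapped into unit complex numbers via torch.polar), and RopeEmbedder.__call__ gathered the rows matching each token's position ids and concatenated the axes. The replacement EmbedND, already used by ComfyUI's other DiT models, computes the same kind of rotary phases directly from the ids; that characterization of EmbedND is my reading, not something shown in this diff. A self-contained sketch of the removed pipeline, with made-up small shapes:

import torch

# Sketch only: tiny shapes for illustration; the real model uses axes_dims=[16, 56, 56], axes_lens=[1, 512, 512].
theta = 10000.0
axes_dims = [4, 8]
axes_lens = [16, 32]
batch, seq = 2, 5

# Step 1 (removed precompute_freqs_cis): one complex table per axis,
# shape (axes_lens[i], axes_dims[i] // 2), entries exp(1j * position * frequency).
freqs_cis = []
for d, e in zip(axes_dims, axes_lens):
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    positions = torch.arange(e, dtype=torch.float64)
    angles = torch.outer(positions, freqs).float()
    freqs_cis.append(torch.polar(torch.ones_like(angles), angles))  # complex64

# Step 2 (removed RopeEmbedder.__call__): per axis, gather the table row matching
# each token's position id, then concatenate the axes along the last dimension.
ids = torch.randint(0, min(axes_lens), (batch, seq, len(axes_dims)))  # (B, L, n_axes)
result = []
for i in range(len(axes_dims)):
    index = ids[:, :, i:i + 1].repeat(1, 1, freqs_cis[i].shape[-1])
    table = freqs_cis[i].unsqueeze(0).repeat(batch, 1, 1)
    result.append(torch.gather(table, dim=1, index=index))
rope = torch.cat(result, dim=-1)
print(rope.shape, rope.dtype)  # torch.Size([2, 5, 6]) torch.complex64
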
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.13"
+__version__ = "0.3.14"

@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.13"
+version = "0.3.14"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"

@@ -1,115 +0,0 @@
-import pytest
-from aiohttp import web
-from unittest.mock import MagicMock, patch
-from api_server.routes.internal.internal_routes import InternalRoutes
-from api_server.services.file_service import FileService
-from folder_paths import models_dir, user_directory, output_directory
-
-
-@pytest.fixture
-def internal_routes():
-    return InternalRoutes(None)
-
-@pytest.fixture
-def aiohttp_client_factory(aiohttp_client, internal_routes):
-    async def _get_client():
-        app = internal_routes.get_app()
-        return await aiohttp_client(app)
-    return _get_client
-
-@pytest.mark.asyncio
-async def test_list_files_valid_directory(aiohttp_client_factory, internal_routes):
-    mock_file_list = [
-        {"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
-        {"name": "dir1", "path": "dir1", "type": "directory"}
-    ]
-    internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
-    client = await aiohttp_client_factory()
-    resp = await client.get('/files?directory=models')
-    assert resp.status == 200
-    data = await resp.json()
-    assert 'files' in data
-    assert len(data['files']) == 2
-    assert data['files'] == mock_file_list
-
-    # Check other valid directories
-    resp = await client.get('/files?directory=user')
-    assert resp.status == 200
-    resp = await client.get('/files?directory=output')
-    assert resp.status == 200
-
-@pytest.mark.asyncio
-async def test_list_files_invalid_directory(aiohttp_client_factory, internal_routes):
-    internal_routes.file_service.list_files = MagicMock(side_effect=ValueError("Invalid directory key"))
-    client = await aiohttp_client_factory()
-    resp = await client.get('/files?directory=invalid')
-    assert resp.status == 400
-    data = await resp.json()
-    assert 'error' in data
-    assert data['error'] == "Invalid directory key"
-
-@pytest.mark.asyncio
-async def test_list_files_exception(aiohttp_client_factory, internal_routes):
-    internal_routes.file_service.list_files = MagicMock(side_effect=Exception("Unexpected error"))
-    client = await aiohttp_client_factory()
-    resp = await client.get('/files?directory=models')
-    assert resp.status == 500
-    data = await resp.json()
-    assert 'error' in data
-    assert data['error'] == "Unexpected error"
-
-@pytest.mark.asyncio
-async def test_list_files_no_directory_param(aiohttp_client_factory, internal_routes):
-    mock_file_list = []
-    internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
-    client = await aiohttp_client_factory()
-    resp = await client.get('/files')
-    assert resp.status == 200
-    data = await resp.json()
-    assert 'files' in data
-    assert len(data['files']) == 0
-
-def test_setup_routes(internal_routes):
-    internal_routes.setup_routes()
-    routes = internal_routes.routes
-    assert any(route.method == 'GET' and str(route.path) == '/files' for route in routes)
-
-def test_get_app(internal_routes):
-    app = internal_routes.get_app()
-    assert isinstance(app, web.Application)
-    assert internal_routes._app is not None
-
-def test_get_app_reuse(internal_routes):
-    app1 = internal_routes.get_app()
-    app2 = internal_routes.get_app()
-    assert app1 is app2
-
-@pytest.mark.asyncio
-async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
-    client = await aiohttp_client_factory()
-    try:
-        resp = await client.get('/files')
-        print(f"Response received: status {resp.status}") # noqa: T201
-    except Exception as e:
-        print(f"Exception occurred during GET request: {e}") # noqa: T201
-        raise
-
-    assert resp.status != 404, "Route /files does not exist"
-
-@pytest.mark.asyncio
-async def test_file_service_initialization():
-    with patch('api_server.routes.internal.internal_routes.FileService') as MockFileService:
-        # Create a mock instance
-        mock_file_service_instance = MagicMock(spec=FileService)
-        MockFileService.return_value = mock_file_service_instance
-        internal_routes = InternalRoutes(None)
-
-        # Check if FileService was initialized with the correct parameters
-        MockFileService.assert_called_once_with({
-            "models": models_dir,
-            "user": user_directory,
-            "output": output_directory
-        })
-
-        # Verify that the file_service attribute of InternalRoutes is set
-        assert internal_routes.file_service == mock_file_service_instance

@@ -1,54 +0,0 @@
-import pytest
-from unittest.mock import MagicMock
-from api_server.services.file_service import FileService
-
-@pytest.fixture
-def mock_file_system_ops():
-    return MagicMock()
-
-@pytest.fixture
-def file_service(mock_file_system_ops):
-    allowed_directories = {
-        "models": "/path/to/models",
-        "user": "/path/to/user",
-        "output": "/path/to/output"
-    }
-    return FileService(allowed_directories, file_system_ops=mock_file_system_ops)
-
-def test_list_files_valid_directory(file_service, mock_file_system_ops):
-    mock_file_system_ops.walk_directory.return_value = [
-        {"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
-        {"name": "dir1", "path": "dir1", "type": "directory"}
-    ]
-
-    result = file_service.list_files("models")
-
-    assert len(result) == 2
-    assert result[0]["name"] == "file1.txt"
-    assert result[1]["name"] == "dir1"
-    mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
-
-def test_list_files_invalid_directory(file_service):
-    # Does not support walking directories outside of the allowed directories
-    with pytest.raises(ValueError, match="Invalid directory key"):
-        file_service.list_files("invalid_key")
-
-def test_list_files_empty_directory(file_service, mock_file_system_ops):
-    mock_file_system_ops.walk_directory.return_value = []
-
-    result = file_service.list_files("models")
-
-    assert len(result) == 0
-    mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
-
-@pytest.mark.parametrize("directory_key", ["models", "user", "output"])
-def test_list_files_all_allowed_directories(file_service, mock_file_system_ops, directory_key):
-    mock_file_system_ops.walk_directory.return_value = [
-        {"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
-    ]
-
-    result = file_service.list_files(directory_key)
-
-    assert len(result) == 1
-    assert result[0]["name"] == f"file_{directory_key}.txt"
-    mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")