mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Each folder in the listing now carries its effective display filter: the registered extension set or, for match-all folders (registered with an empty set, e.g. LLM), the global supported_pt_extensions. The filter is resolved per request, so extensions registered at runtime by custom nodes are reflected. This lets the asset-backed model sidebar filter the contents of match-all folders (hiding README/config noise) without hardcoding the global set on the client side, while registered folders remain governed by the scanner's extension-filtered model_type tags.
85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
import pytest
|
|
import base64
|
|
import json
|
|
import struct
|
|
from io import BytesIO
|
|
from PIL import Image
|
|
from aiohttp import web
|
|
from unittest.mock import patch
|
|
from app.model_manager import ModelFileManager
|
|
|
|
# Module-level mark: pytest applies pytest.mark.asyncio to every test
# function defined in this module, so individual tests need no decorator.
pytestmark = pytest.mark.asyncio
|
|
|
|
@pytest.fixture()
def model_manager():
    """Provide a newly constructed ModelFileManager (function-scoped fixture)."""
    manager = ModelFileManager()
    return manager
|
|
|
|
@pytest.fixture()
def app(model_manager):
    """Build an aiohttp application serving the model manager's routes.

    The manager registers its handlers on a RouteTableDef, which is then
    attached to a fresh web.Application for the test client to wrap.
    """
    application = web.Application()
    route_table = web.RouteTableDef()
    model_manager.add_routes(route_table)
    application.add_routes(route_table)
    return application
|
|
|
|
async def test_get_model_folders_includes_effective_extensions(aiohttp_client, app, tmp_path):
    """Folders expose their effective display filter: the registered extension
    set, or the global supported_pt_extensions for match-all (empty) folders."""
    base = str(tmp_path)
    registered = {
        'test_checkpoints': ([base], {'.safetensors', '.ckpt'}),
        'test_configs': ([base], ['.yaml']),
        'test_match_all': ([base], set()),
        'configs': ([base], ['.yaml']),
    }

    # Patch both module attributes of folder_paths for the duration of the request.
    with patch.multiple(
        'folder_paths',
        folder_names_and_paths=registered,
        supported_pt_extensions={'.safetensors', '.bin'},
    ):
        client = await aiohttp_client(app)
        response = await client.get('/experiment/models')

        assert response.status == 200
        listing = await response.json()
        folders = {entry['name']: entry for entry in listing}

        assert 'configs' not in folders  # blocklisted
        assert folders['test_checkpoints']['folders'] == [base]

        # The response lists extensions in sorted order.
        expected_extensions = {
            'test_checkpoints': ['.ckpt', '.safetensors'],
            'test_configs': ['.yaml'],
            # Match-all folders substitute the live global set.
            'test_match_all': ['.bin', '.safetensors'],
        }
        for name, extensions in expected_extensions.items():
            assert folders[name]['extensions'] == extensions, name
|
|
|
|
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
    """The preview route serves a cover image embedded in a .safetensors header as WebP.

    Writes a minimal .safetensors file: an 8-byte little-endian length prefix
    followed by a JSON header. The header's ``__metadata__["ssmd_cover_images"]``
    is a JSON-encoded list holding one base64 PNG. The test then requests the
    file through the preview endpoint and checks that the body decodes as WebP.
    """
    # Encode a small white PNG as base64 so it can be embedded in the metadata.
    png_buffer = BytesIO()
    Image.new('RGB', (100, 100), 'white').save(png_buffer, format='PNG')
    img_b64 = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

    # The cover list is itself JSON-encoded into a string value inside __metadata__.
    header_bytes = json.dumps({
        "__metadata__": {
            "ssmd_cover_images": json.dumps([img_b64])
        }
    }).encode('utf-8')
    safetensors_file = tmp_path / "test_model.safetensors"
    safetensors_file.write_bytes(struct.pack('<Q', len(header_bytes)) + header_bytes)

    with patch('folder_paths.folder_names_and_paths', {
        'test_folder': ([str(tmp_path)], None)
    }):
        client = await aiohttp_client(app)
        response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')

        # Verify response
        assert response.status == 200
        assert response.content_type == 'image/webp'

        # Verify the response contains valid image data. The context managers
        # close the buffer and release the image even if an assertion fails,
        # unlike a trailing manual close().
        with BytesIO(await response.read()) as body, Image.open(body) as preview:
            assert preview.format
            assert preview.format.lower() == 'webp'
|