full test suite wip

This commit is contained in:
doctorpangloss 2025-12-09 16:29:05 -08:00
parent 9c892a9b34
commit 7338873262
2 changed files with 10 additions and 13 deletions

View File

@ -86,14 +86,12 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
for k in sensitive: for k in sensitive:
extra_data[k] = sensitive[k] extra_data[k] = sensitive[k]
e.execute(item[2], prompt_id, extra_data, item[4]) # todo: ??? what jank
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
await e.execute_async(item[2], prompt_id, item[3], item[4]) await e.execute_async(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
# Extract error details from status_messages if there's an error
error_details = None error_details = None
if not e.success: if not e.success:
for event, data in e.status_messages: for event, data in e.status_messages:
@ -101,7 +99,6 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
error_details = data error_details = data
break break
# Convert status_messages tuples to string messages for backward compatibility
messages = [f"{event}: {data.get('exception_message', str(data))}" if isinstance(data, dict) and 'exception_message' in data else f"{event}" for event, data in e.status_messages] messages = [f"{event}: {data.get('exception_message', str(data))}" if isinstance(data, dict) and 'exception_message' in data else f"{event}" for event, data in e.status_messages]
q.task_done(item_id, q.task_done(item_id,
@ -150,8 +147,8 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
hook_breaker_ac10a0.restore_functions() hook_breaker_ac10a0.restore_functions()
def prompt_worker(q, server): def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
asyncio.run(_prompt_worker(q, server)) asyncio.run(_prompt_worker(q, server_instance))
async def run(server_instance, address='', port=8188, call_on_start=None): async def run(server_instance, address='', port=8188, call_on_start=None):

View File

@ -25,8 +25,8 @@ def _generate_config_params():
"use_pytorch_cross_attention", "use_pytorch_cross_attention",
# "use_split_cross_attention", # "use_split_cross_attention",
# "use_quad_cross_attention", # "use_quad_cross_attention",
"use_sage_attention", # "use_sage_attention",
"use_flash_attention" # "use_flash_attention"
] ]
attn_options = [ attn_options = [
{k: (k == target_key) for k in attn_keys} {k: (k == target_key) for k in attn_keys}
@ -35,17 +35,17 @@ def _generate_config_params():
async_options = [ async_options = [
{"disable_async_offload": False}, {"disable_async_offload": False},
{"disable_async_offload": True}, # {"disable_async_offload": True},
] ]
pinned_options = [ pinned_options = [
{"disable_pinned_memory": False}, {"disable_pinned_memory": False},
{"disable_pinned_memory": True}, # {"disable_pinned_memory": True},
] ]
fast_options = [ fast_options = [
{"fast": set()}, {"fast": set()},
{"fast": {PerformanceFeature.Fp16Accumulation}}, # {"fast": {PerformanceFeature.Fp16Accumulation}},
{"fast": {PerformanceFeature.Fp8MatrixMultiplication}}, # {"fast": {PerformanceFeature.Fp8MatrixMultiplication}},
{"fast": {PerformanceFeature.CublasOps}}, # {"fast": {PerformanceFeature.CublasOps}},
] ]
for attn, asnc, pinned, fst in itertools.product(attn_options, async_options, pinned_options, fast_options): for attn, asnc, pinned, fst in itertools.product(attn_options, async_options, pinned_options, fast_options):