Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-20 03:23:00 +08:00)
Commit fc93133a1b: Merge branch 'master' into v3-improvements
@@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
    minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
    return memory_required, minimum_memory_required

-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
    executor = comfy.patcher_extension.WrapperExecutor.new_executor(
        _prepare_sampling,
        comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
    )
-    return executor.execute(model, noise_shape, conds, model_options=model_options)
+    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)

-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
    real_model: BaseModel = None
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    models += get_additional_models_from_model_options(model_options)
    models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
    real_model = model.model

    return real_model, conds, models
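The new force_full_load flag simply threads from prepare_sampling through the wrapper executor into load_models_gpu. A minimal caller sketch, assuming only the signatures shown above (the surrounding variable names are illustrative, not part of the commit):

# Hypothetical caller: keep the whole model resident instead of allowing partial offload.
real_model, conds, models = comfy.sampler_helpers.prepare_sampling(
    model_patcher, noise.shape, conds, model_options, force_full_load=True
)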
comfy_execution/jobs.py (new file, 291 lines)
@@ -0,0 +1,291 @@
"""
Job utilities for the /api/jobs endpoint.
Provides normalization and helper functions for job status tracking.
"""

from typing import Optional

from comfy_api.internal import prune_dict


class JobStatus:
    """Job status constants."""
    PENDING = 'pending'
    IN_PROGRESS = 'in_progress'
    COMPLETED = 'completed'
    FAILED = 'failed'

    ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]


# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})

# 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})


def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
    """Extract create_time and workflow_id from extra_data.

    Returns:
        tuple: (create_time, workflow_id)
    """
    create_time = extra_data.get('create_time')
    extra_pnginfo = extra_data.get('extra_pnginfo', {})
    workflow_id = extra_pnginfo.get('workflow', {}).get('id')
    return create_time, workflow_id


def is_previewable(media_type: str, item: dict) -> bool:
    """
    Check if an output item is previewable.
    Matches frontend logic in ComfyUI_frontend/src/stores/queueStore.ts
    Maintains backwards compatibility with existing logic.

    Priority:
    1. media_type is 'images', 'video', or 'audio'
    2. format field starts with 'video/' or 'audio/'
    3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
    """
    if media_type in PREVIEWABLE_MEDIA_TYPES:
        return True

    # Check format field (MIME type).
    # Maintains backwards compatibility with how custom node outputs are handled in the frontend.
    fmt = item.get('format', '')
    if fmt and (fmt.startswith('video/') or fmt.startswith('audio/')):
        return True

    # Check for 3D files by extension
    filename = item.get('filename', '').lower()
    if any(filename.endswith(ext) for ext in THREE_D_EXTENSIONS):
        return True

    return False


def normalize_queue_item(item: tuple, status: str) -> dict:
    """Convert queue item tuple to unified job dict.

    Expects item with sensitive data already removed (5 elements).
    """
    priority, prompt_id, _, extra_data, _ = item
    create_time, workflow_id = _extract_job_metadata(extra_data)

    return prune_dict({
        'id': prompt_id,
        'status': status,
        'priority': priority,
        'create_time': create_time,
        'outputs_count': 0,
        'workflow_id': workflow_id,
    })


def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: bool = False) -> dict:
    """Convert history item dict to unified job dict.

    History items have sensitive data already removed (prompt tuple has 5 elements).
    """
    prompt_tuple = history_item['prompt']
    priority, _, prompt, extra_data, _ = prompt_tuple
    create_time, workflow_id = _extract_job_metadata(extra_data)

    status_info = history_item.get('status', {})
    status_str = status_info.get('status_str') if status_info else None
    if status_str == 'success':
        status = JobStatus.COMPLETED
    elif status_str == 'error':
        status = JobStatus.FAILED
    else:
        status = JobStatus.COMPLETED

    outputs = history_item.get('outputs', {})
    outputs_count, preview_output = get_outputs_summary(outputs)

    execution_error = None
    execution_start_time = None
    execution_end_time = None
    if status_info:
        messages = status_info.get('messages', [])
        for entry in messages:
            if isinstance(entry, (list, tuple)) and len(entry) >= 2:
                event_name, event_data = entry[0], entry[1]
                if isinstance(event_data, dict):
                    if event_name == 'execution_start':
                        execution_start_time = event_data.get('timestamp')
                    elif event_name in ('execution_success', 'execution_error', 'execution_interrupted'):
                        execution_end_time = event_data.get('timestamp')
                        if event_name == 'execution_error':
                            execution_error = event_data

    job = prune_dict({
        'id': prompt_id,
        'status': status,
        'priority': priority,
        'create_time': create_time,
        'execution_start_time': execution_start_time,
        'execution_end_time': execution_end_time,
        'execution_error': execution_error,
        'outputs_count': outputs_count,
        'preview_output': preview_output,
        'workflow_id': workflow_id,
    })

    if include_outputs:
        job['outputs'] = outputs
        job['execution_status'] = status_info
        job['workflow'] = {
            'prompt': prompt,
            'extra_data': extra_data,
        }

    return job


def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
    """
    Count outputs and find preview in a single pass.
    Returns (outputs_count, preview_output).

    Preview priority (matching frontend):
    1. type="output" with previewable media
    2. Any previewable media
    """
    count = 0
    preview_output = None
    fallback_preview = None

    for node_id, node_outputs in outputs.items():
        if not isinstance(node_outputs, dict):
            continue
        for media_type, items in node_outputs.items():
            # 'animated' is a boolean flag, not actual output items
            if media_type == 'animated' or not isinstance(items, list):
                continue

            for item in items:
                if not isinstance(item, dict):
                    continue
                count += 1

                if preview_output is None and is_previewable(media_type, item):
                    enriched = {
                        **item,
                        'nodeId': node_id,
                        'mediaType': media_type
                    }
                    if item.get('type') == 'output':
                        preview_output = enriched
                    elif fallback_preview is None:
                        fallback_preview = enriched

    return count, preview_output or fallback_preview


def apply_sorting(jobs: list[dict], sort_by: str, sort_order: str) -> list[dict]:
    """Sort jobs list by specified field and order."""
    reverse = (sort_order == 'desc')

    if sort_by == 'execution_duration':
        def get_sort_key(job):
            start = job.get('execution_start_time', 0)
            end = job.get('execution_end_time', 0)
            return end - start if end and start else 0
    else:
        def get_sort_key(job):
            return job.get('create_time', 0)

    return sorted(jobs, key=get_sort_key, reverse=reverse)


def get_job(prompt_id: str, running: list, queued: list, history: dict) -> Optional[dict]:
    """
    Get a single job by prompt_id from history or queue.

    Args:
        prompt_id: The prompt ID to look up
        running: List of currently running queue items
        queued: List of pending queue items
        history: Dict of history items keyed by prompt_id

    Returns:
        Job dict with full details, or None if not found
    """
    if prompt_id in history:
        return normalize_history_item(prompt_id, history[prompt_id], include_outputs=True)

    for item in running:
        if item[1] == prompt_id:
            return normalize_queue_item(item, JobStatus.IN_PROGRESS)

    for item in queued:
        if item[1] == prompt_id:
            return normalize_queue_item(item, JobStatus.PENDING)

    return None


def get_all_jobs(
    running: list,
    queued: list,
    history: dict,
    status_filter: Optional[list[str]] = None,
    workflow_id: Optional[str] = None,
    sort_by: str = "created_at",
    sort_order: str = "desc",
    limit: Optional[int] = None,
    offset: int = 0
) -> tuple[list[dict], int]:
    """
    Get all jobs (running, pending, completed) with filtering and sorting.

    Args:
        running: List of currently running queue items
        queued: List of pending queue items
        history: Dict of history items keyed by prompt_id
        status_filter: List of statuses to include (from JobStatus.ALL)
        workflow_id: Filter by workflow ID
        sort_by: Field to sort by ('created_at', 'execution_duration')
        sort_order: 'asc' or 'desc'
        limit: Maximum number of items to return
        offset: Number of items to skip

    Returns:
        tuple: (jobs_list, total_count)
    """
    jobs = []

    if status_filter is None:
        status_filter = JobStatus.ALL

    if JobStatus.IN_PROGRESS in status_filter:
        for item in running:
            jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))

    if JobStatus.PENDING in status_filter:
        for item in queued:
            jobs.append(normalize_queue_item(item, JobStatus.PENDING))

    include_completed = JobStatus.COMPLETED in status_filter
    include_failed = JobStatus.FAILED in status_filter
    if include_completed or include_failed:
        for prompt_id, history_item in history.items():
            is_failed = history_item.get('status', {}).get('status_str') == 'error'
            if (is_failed and include_failed) or (not is_failed and include_completed):
                jobs.append(normalize_history_item(prompt_id, history_item))

    if workflow_id:
        jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]

    jobs = apply_sorting(jobs, sort_by, sort_order)

    total_count = len(jobs)

    if offset > 0:
        jobs = jobs[offset:]
    if limit is not None:
        jobs = jobs[:limit]

    return (jobs, total_count)
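A hedged sketch of how a /api/jobs handler could use these helpers. The handler names and the exact shape of running/queued/history are assumptions; only the comfy_execution.jobs functions above are taken from the commit:

from comfy_execution.jobs import JobStatus, get_all_jobs, get_job

def list_failed_jobs(running, queued, history):
    # Newest-first page of failed jobs, 20 per page
    jobs, total = get_all_jobs(
        running, queued, history,
        status_filter=[JobStatus.FAILED],
        sort_by="created_at",
        sort_order="desc",
        limit=20,
        offset=0,
    )
    return {"jobs": jobs, "total": total}

def job_detail(prompt_id, running, queued, history):
    # Full record (outputs included) or None when the id is unknown
    return get_job(prompt_id, running, queued, history)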
@@ -1125,6 +1125,99 @@ class MergeTextListsNode(TextProcessingNode):
# ========== Training Dataset Nodes ==========

+
+class ResolutionBucket(io.ComfyNode):
+    """Bucket latents and conditions by resolution for efficient batch training."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ResolutionBucket",
+            display_name="Resolution Bucket",
+            category="dataset",
+            is_experimental=True,
+            is_input_list=True,
+            inputs=[
+                io.Latent.Input(
+                    "latents",
+                    tooltip="List of latent dicts to bucket by resolution.",
+                ),
+                io.Conditioning.Input(
+                    "conditioning",
+                    tooltip="List of conditioning lists (must match latents length).",
+                ),
+            ],
+            outputs=[
+                io.Latent.Output(
+                    display_name="latents",
+                    is_output_list=True,
+                    tooltip="List of batched latent dicts, one per resolution bucket.",
+                ),
+                io.Conditioning.Output(
+                    display_name="conditioning",
+                    is_output_list=True,
+                    tooltip="List of condition lists, one per resolution bucket.",
+                ),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, latents, conditioning):
+        # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
+        # conditioning: list[list[cond]]
+
+        # Validate lengths match
+        if len(latents) != len(conditioning):
+            raise ValueError(
+                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
+            )
+
+        # Flatten latents and conditions to individual samples
+        flat_latents = []  # list of (C, H, W) tensors
+        flat_conditions = []  # list of condition lists
+
+        for latent_dict, cond in zip(latents, conditioning):
+            samples = latent_dict["samples"]  # (B, C, H, W)
+            batch_size = samples.shape[0]
+
+            # cond is a list of conditions with length == batch_size
+            for i in range(batch_size):
+                flat_latents.append(samples[i])  # (C, H, W)
+                flat_conditions.append(cond[i])  # single condition
+
+        # Group by resolution (H, W)
+        buckets = {}  # (H, W) -> {"latents": list, "conditions": list}
+
+        for latent, cond in zip(flat_latents, flat_conditions):
+            # latent shape is (..., H, W): (B, C, H, W) or (B, T, C, H, W)
+            h, w = latent.shape[-2], latent.shape[-1]
+            key = (h, w)
+
+            if key not in buckets:
+                buckets[key] = {"latents": [], "conditions": []}
+
+            buckets[key]["latents"].append(latent)
+            buckets[key]["conditions"].append(cond)
+
+        # Convert buckets to output format
+        output_latents = []  # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
+        output_conditions = []  # list[list[cond]] where each inner list has Bi conditions
+
+        for (h, w), bucket_data in buckets.items():
+            # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
+            stacked_latents = torch.stack(bucket_data["latents"], dim=0)
+            output_latents.append({"samples": stacked_latents})
+
+            # Conditions stay as list of condition lists
+            output_conditions.append(bucket_data["conditions"])
+
+            logging.info(
+                f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
+            )
+
+        logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
+        return io.NodeOutput(output_latents, output_conditions)
+
+
class MakeTrainingDataset(io.ComfyNode):
    """Encode images with VAE and texts with CLIP to create a training dataset."""
@@ -1373,7 +1466,7 @@ class LoadTrainingDataset(io.ComfyNode):
            shard_path = os.path.join(dataset_dir, shard_file)

            with open(shard_path, "rb") as f:
-                shard_data = torch.load(f, weights_only=True)
+                shard_data = torch.load(f)

            all_latents.extend(shard_data["latents"])
            all_conditioning.extend(shard_data["conditioning"])
@@ -1425,6 +1518,7 @@ class DatasetExtension(ComfyExtension):
            MakeTrainingDataset,
            SaveTrainingDataset,
            LoadTrainingDataset,
+            ResolutionBucket,
        ]
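The core of ResolutionBucket is a group-by on the trailing (H, W) of each sample. A standalone sketch of that grouping with plain tensors (assumed shapes, written outside the node API for illustration):

import torch

def bucket_by_resolution(samples: list[torch.Tensor]) -> dict[tuple[int, int], torch.Tensor]:
    # samples: individual latents shaped (C, H, W); returns one (Bi, C, H, W) batch per resolution
    buckets: dict[tuple[int, int], list[torch.Tensor]] = {}
    for s in samples:
        buckets.setdefault((s.shape[-2], s.shape[-1]), []).append(s)
    return {k: torch.stack(v, dim=0) for k, v in buckets.items()}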
@@ -224,6 +224,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
                io.Image.Input("image"),
                io.Combo.Input("upscale_method", options=cls.upscale_methods),
                io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
+                io.Int.Input("resolution_steps", default=1, min=1, max=256),
            ],
            outputs=[
                io.Image.Output(),
@@ -231,15 +232,15 @@ class ImageScaleToTotalPixels(io.ComfyNode):
        )

    @classmethod
-    def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
+    def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
        samples = image.movedim(-1,1)
-        total = int(megapixels * 1024 * 1024)
+        total = megapixels * 1024 * 1024

        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
-        width = round(samples.shape[3] * scale_by)
-        height = round(samples.shape[2] * scale_by)
+        width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
+        height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps

-        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+        s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
        s = s.movedim(1,-1)
        return io.NodeOutput(s)
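A worked example of the new rounding, with illustrative numbers that are not part of the commit:

import math

megapixels, resolution_steps = 1.0, 64
height, width = 1080, 1920                       # input image size
total = megapixels * 1024 * 1024
scale_by = math.sqrt(total / (width * height))   # ~0.711
new_w = round(width * scale_by / resolution_steps) * resolution_steps    # 1344
new_h = round(height * scale_by / resolution_steps) * resolution_steps   # 768
# With resolution_steps=1 the same input would scale to 1365 x 768 instead.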
@@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override

import comfy.samplers
+import comfy.sampler_helpers
import comfy.sd
import comfy.utils
import comfy.model_management
@@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar

+
+class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
+    """
+    CFGGuider with modifications for training specific logic
+    """
+    def outer_sample(
+        self,
+        noise,
+        latent_image,
+        sampler,
+        sigmas,
+        denoise_mask=None,
+        callback=None,
+        disable_pbar=False,
+        seed=None,
+        latent_shapes=None,
+    ):
+        self.inner_model, self.conds, self.loaded_models = (
+            comfy.sampler_helpers.prepare_sampling(
+                self.model_patcher,
+                noise.shape,
+                self.conds,
+                self.model_options,
+                force_full_load=True,  # mirror behavior in TrainLoraNode.execute() to keep model loaded
+            )
+        )
+        device = self.model_patcher.load_device
+
+        if denoise_mask is not None:
+            denoise_mask = comfy.sampler_helpers.prepare_mask(
+                denoise_mask, noise.shape, device
+            )
+
+        noise = noise.to(device)
+        latent_image = latent_image.to(device)
+        sigmas = sigmas.to(device)
+        comfy.samplers.cast_to_load_options(
+            self.model_options, device=device, dtype=self.model_patcher.model_dtype()
+        )
+
+        try:
+            self.model_patcher.pre_run()
+            output = self.inner_sample(
+                noise,
+                latent_image,
+                device,
+                sampler,
+                sigmas,
+                denoise_mask,
+                callback,
+                disable_pbar,
+                seed,
+                latent_shapes=latent_shapes,
+            )
+        finally:
+            self.model_patcher.cleanup()
+
+        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
+        del self.inner_model
+        del self.loaded_models
+        return output
+
+
def make_batch_extra_option_dict(d, indicies, full_size=None):
    new_dict = {}
    for k, v in d.items():
@@ -65,6 +128,7 @@ class TrainSampler(comfy.samplers.Sampler):
        seed=0,
        training_dtype=torch.bfloat16,
        real_dataset=None,
+        bucket_latents=None,
    ):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
@@ -75,6 +139,28 @@ class TrainSampler(comfy.samplers.Sampler):
        self.seed = seed
        self.training_dtype = training_dtype
        self.real_dataset: list[torch.Tensor] | None = real_dataset
+        # Bucket mode data
+        self.bucket_latents: list[torch.Tensor] | None = (
+            bucket_latents  # list of (Bi, C, Hi, Wi)
+        )
+        # Precompute bucket offsets and weights for sampling
+        if bucket_latents is not None:
+            self._init_bucket_data(bucket_latents)
+        else:
+            self.bucket_offsets = None
+            self.bucket_weights = None
+            self.num_images = None
+
+    def _init_bucket_data(self, bucket_latents):
+        """Initialize bucket offsets and weights for sampling."""
+        self.bucket_offsets = [0]
+        bucket_sizes = []
+        for lat in bucket_latents:
+            bucket_sizes.append(lat.shape[0])
+            self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
+        self.num_images = self.bucket_offsets[-1]
+        # Weights for sampling buckets proportional to their size
+        self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
+
    def fwd_bwd(
        self,
@@ -115,6 +201,108 @@ class TrainSampler(comfy.samplers.Sampler):
            bwd_loss.backward()
        return loss

+    def _generate_batch_sigmas(self, model_wrap, batch_size, device):
+        """Generate random sigma values for a batch."""
+        batch_sigmas = [
+            model_wrap.inner_model.model_sampling.percent_to_sigma(
+                torch.rand((1,)).item()
+            )
+            for _ in range(batch_size)
+        ]
+        return torch.tensor(batch_sigmas).to(device)
+
+    def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
+        """Execute one training step in bucket mode."""
+        # Sample bucket (weighted by size), then sample batch from bucket
+        bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
+        bucket_latent = self.bucket_latents[bucket_idx]  # (Bi, C, Hi, Wi)
+        bucket_size = bucket_latent.shape[0]
+        bucket_offset = self.bucket_offsets[bucket_idx]
+
+        # Sample indices from this bucket (use all if bucket_size < batch_size)
+        actual_batch_size = min(self.batch_size, bucket_size)
+        relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
+        # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
+        absolute_indices = [bucket_offset + idx for idx in relative_indices]
+
+        batch_latent = bucket_latent[relative_indices].to(latent_image)  # (actual_batch_size, C, H, W)
+        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
+            batch_latent.device
+        )
+        batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
+
+        loss = self.fwd_bwd(
+            model_wrap,
+            batch_sigmas,
+            batch_noise,
+            batch_latent,
+            cond,  # Use flattened cond with absolute indices
+            absolute_indices,
+            extra_args,
+            self.num_images,
+            bwd=True,
+        )
+        if self.loss_callback:
+            self.loss_callback(loss.item())
+        pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
+
+    def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
+        """Execute one training step in standard (non-bucket, non-multi-res) mode."""
+        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
+        batch_latent = torch.stack([latent_image[i] for i in indicies])
+        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
+            batch_latent.device
+        )
+        batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
+
+        loss = self.fwd_bwd(
+            model_wrap,
+            batch_sigmas,
+            batch_noise,
+            batch_latent,
+            cond,
+            indicies,
+            extra_args,
+            dataset_size,
+            bwd=True,
+        )
+        if self.loss_callback:
+            self.loss_callback(loss.item())
+        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+    def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
+        """Execute one training step in multi-resolution mode (real_dataset is set)."""
+        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
+        total_loss = 0
+        for index in indicies:
+            single_latent = self.real_dataset[index].to(latent_image)
+            batch_noise = noisegen.generate_noise(
+                {"samples": single_latent}
+            ).to(single_latent.device)
+            batch_sigmas = (
+                model_wrap.inner_model.model_sampling.percent_to_sigma(
+                    torch.rand((1,)).item()
+                )
+            )
+            batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
+            loss = self.fwd_bwd(
+                model_wrap,
+                batch_sigmas,
+                batch_noise,
+                single_latent,
+                cond,
+                [index],
+                extra_args,
+                dataset_size,
+                bwd=False,
+            )
+            total_loss += loss
+        total_loss = total_loss / self.grad_acc / len(indicies)
+        total_loss.backward()
+        if self.loss_callback:
+            self.loss_callback(total_loss.item())
+        pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
+
    def sample(
        self,
        model_wrap,
@@ -142,70 +330,18 @@ class TrainSampler(comfy.samplers.Sampler):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
                self.seed + i * 1000
            )
-            indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
-
-            if self.real_dataset is None:
-                batch_latent = torch.stack([latent_image[i] for i in indicies])
-                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
-                    batch_latent.device
-                )
-                batch_sigmas = [
-                    model_wrap.inner_model.model_sampling.percent_to_sigma(
-                        torch.rand((1,)).item()
-                    )
-                    for _ in range(min(self.batch_size, dataset_size))
-                ]
-                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
-
-                loss = self.fwd_bwd(
-                    model_wrap,
-                    batch_sigmas,
-                    batch_noise,
-                    batch_latent,
-                    cond,
-                    indicies,
-                    extra_args,
-                    dataset_size,
-                    bwd=True,
-                )
-                if self.loss_callback:
-                    self.loss_callback(loss.item())
-                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+            if self.bucket_latents is not None:
+                self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
+            elif self.real_dataset is None:
+                self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
            else:
-                total_loss = 0
-                for index in indicies:
-                    single_latent = self.real_dataset[index].to(latent_image)
-                    batch_noise = noisegen.generate_noise(
-                        {"samples": single_latent}
-                    ).to(single_latent.device)
-                    batch_sigmas = (
-                        model_wrap.inner_model.model_sampling.percent_to_sigma(
-                            torch.rand((1,)).item()
-                        )
-                    )
-                    batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
-                    loss = self.fwd_bwd(
-                        model_wrap,
-                        batch_sigmas,
-                        batch_noise,
-                        single_latent,
-                        cond,
-                        [index],
-                        extra_args,
-                        dataset_size,
-                        bwd=False,
-                    )
-                    total_loss += loss
-                total_loss = total_loss / self.grad_acc / len(indicies)
-                total_loss.backward()
-                if self.loss_callback:
-                    self.loss_callback(total_loss.item())
-                pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
+                self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)

            if (i + 1) % self.grad_acc == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
            ui_pbar.update(1)
            torch.cuda.empty_cache()
        return torch.zeros_like(latent_image)
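Bucket mode picks a bucket weighted by its size, so each individual sample remains equally likely, then maps the in-bucket index back into the flattened conditioning list through the precomputed offsets. A small standalone sketch with assumed bucket sizes (illustrative, not part of the commit):

import torch

bucket_sizes = [8, 2]                                  # e.g. two buckets of 8 and 2 samples
offsets = [0]
for s in bucket_sizes:
    offsets.append(offsets[-1] + s)                    # offsets == [0, 8, 10]
weights = torch.tensor(bucket_sizes, dtype=torch.float32)

bucket_idx = torch.multinomial(weights, 1).item()      # bucket 0 drawn ~80% of the time
local = torch.randperm(bucket_sizes[bucket_idx])[:4].tolist()
absolute = [offsets[bucket_idx] + i for i in local]    # indices into the flattened cond list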
@@ -283,6 +419,364 @@ def unpatch(m):
    del m.org_forward

+
+def _process_latents_bucket_mode(latents):
+    """Process latents for bucket mode training.
+
+    Args:
+        latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
+
+    Returns:
+        list of latent tensors
+    """
+    bucket_latents = []
+    for latent_dict in latents:
+        bucket_latents.append(latent_dict["samples"])  # (Bi, C, Hi, Wi)
+    return bucket_latents
+
+
+def _process_latents_standard_mode(latents):
+    """Process latents for standard (non-bucket) mode training.
+
+    Args:
+        latents: list of latent dicts or single latent dict
+
+    Returns:
+        Processed latents (tensor or list of tensors)
+    """
+    if len(latents) == 1:
+        return latents[0]["samples"]  # Single latent dict
+
+    latent_list = []
+    for latent in latents:
+        latent = latent["samples"]
+        bs = latent.shape[0]
+        if bs != 1:
+            for sub_latent in latent:
+                latent_list.append(sub_latent[None])
+        else:
+            latent_list.append(latent)
+    return latent_list
+
+
+def _process_conditioning(positive):
+    """Process conditioning - either single list or list of lists.
+
+    Args:
+        positive: list of conditioning
+
+    Returns:
+        Flattened conditioning list
+    """
+    if len(positive) == 1:
+        return positive[0]  # Single conditioning list
+
+    # Multiple conditioning lists - flatten
+    flat_positive = []
+    for cond in positive:
+        if isinstance(cond, list):
+            flat_positive.extend(cond)
+        else:
+            flat_positive.append(cond)
+    return flat_positive
+
+
+def _prepare_latents_and_count(latents, dtype, bucket_mode):
+    """Convert latents to dtype and compute image counts.
+
+    Args:
+        latents: Latents (tensor, list of tensors, or bucket list)
+        dtype: Target dtype
+        bucket_mode: Whether bucket mode is enabled
+
+    Returns:
+        tuple: (processed_latents, num_images, multi_res)
+    """
+    if bucket_mode:
+        # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
+        latents = [t.to(dtype) for t in latents]
+        num_buckets = len(latents)
+        num_images = sum(t.shape[0] for t in latents)
+        multi_res = False  # Not using multi_res path in bucket mode
+
+        logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
+        for i, lat in enumerate(latents):
+            logging.info(f"  Bucket {i}: shape {lat.shape}")
+        return latents, num_images, multi_res
+
+    # Non-bucket mode
+    if isinstance(latents, list):
+        all_shapes = set()
+        latents = [t.to(dtype) for t in latents]
+        for latent in latents:
+            all_shapes.add(latent.shape)
+        logging.info(f"Latent shapes: {all_shapes}")
+        if len(all_shapes) > 1:
+            multi_res = True
+        else:
+            multi_res = False
+            latents = torch.cat(latents, dim=0)
+        num_images = len(latents)
+    elif isinstance(latents, torch.Tensor):
+        latents = latents.to(dtype)
+        num_images = latents.shape[0]
+        multi_res = False
+    else:
+        logging.error(f"Invalid latents type: {type(latents)}")
+        num_images = 0
+        multi_res = False
+
+    return latents, num_images, multi_res
+
+
+def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
+    """Validate conditioning count matches image count, expand if needed.
+
+    Args:
+        positive: Conditioning list
+        num_images: Number of images
+        bucket_mode: Whether bucket mode is enabled
+
+    Returns:
+        Validated/expanded conditioning list
+
+    Raises:
+        ValueError: If conditioning count doesn't match image count
+    """
+    if bucket_mode:
+        return positive  # Skip validation in bucket mode
+
+    logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
+    if len(positive) == 1 and num_images > 1:
+        return positive * num_images
+    elif len(positive) != num_images:
+        raise ValueError(
+            f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
+        )
+    return positive
+
+
+def _load_existing_lora(existing_lora):
+    """Load existing LoRA weights if provided.
+
+    Args:
+        existing_lora: LoRA filename or "[None]"
+
+    Returns:
+        tuple: (existing_weights dict, existing_steps int)
+    """
+    if existing_lora == "[None]":
+        return {}, 0
+
+    lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
+    # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
+    existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
+    existing_weights = {}
+    if lora_path:
+        existing_weights = comfy.utils.load_torch_file(lora_path)
+    return existing_weights, existing_steps
+
+
+def _create_weight_adapter(
+    module, module_name, existing_weights, algorithm, lora_dtype, rank
+):
+    """Create a weight adapter for a module with weight.
+
+    Args:
+        module: The module to create adapter for
+        module_name: Name of the module
+        existing_weights: Dict of existing LoRA weights
+        algorithm: Algorithm name for new adapters
+        lora_dtype: dtype for LoRA weights
+        rank: Rank for new LoRA adapters
+
+    Returns:
+        tuple: (train_adapter, lora_params dict)
+    """
+    key = f"{module_name}.weight"
+    shape = module.weight.shape
+    lora_params = {}
+
+    if len(shape) >= 2:
+        alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
+        dora_scale = existing_weights.get(f"{key}.dora_scale", None)
+
+        # Try to load existing adapter
+        existing_adapter = None
+        for adapter_cls in adapters:
+            existing_adapter = adapter_cls.load(
+                module_name, existing_weights, alpha, dora_scale
+            )
+            if existing_adapter is not None:
+                break
+
+        if existing_adapter is None:
+            adapter_cls = adapter_maps[algorithm]
+
+        if existing_adapter is not None:
+            train_adapter = existing_adapter.to_train().to(lora_dtype)
+        else:
+            # Use LoRA with alpha=1.0 by default
+            train_adapter = adapter_cls.create_train(
+                module.weight, rank=rank, alpha=1.0
+            ).to(lora_dtype)
+
+        for name, parameter in train_adapter.named_parameters():
+            lora_params[f"{module_name}.{name}"] = parameter
+
+        return train_adapter.train().requires_grad_(True), lora_params
+    else:
+        # 1D weight - use BiasDiff
+        diff = torch.nn.Parameter(
+            torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
+        )
+        diff_module = BiasDiff(diff).train().requires_grad_(True)
+        lora_params[f"{module_name}.diff"] = diff
+        return diff_module, lora_params
+
+
+def _create_bias_adapter(module, module_name, lora_dtype):
+    """Create a bias adapter for a module with bias.
+
+    Args:
+        module: The module with bias
+        module_name: Name of the module
+        lora_dtype: dtype for LoRA weights
+
+    Returns:
+        tuple: (bias_module, lora_params dict)
+    """
+    bias = torch.nn.Parameter(
+        torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
+    )
+    bias_module = BiasDiff(bias).train().requires_grad_(True)
+    lora_params = {f"{module_name}.diff_b": bias}
+    return bias_module, lora_params
+
+
+def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
+    """Setup all LoRA adapters on the model.
+
+    Args:
+        mp: Model patcher
+        existing_weights: Dict of existing LoRA weights
+        algorithm: Algorithm name for new adapters
+        lora_dtype: dtype for LoRA weights
+        rank: Rank for new LoRA adapters
+
+    Returns:
+        tuple: (lora_sd dict, all_weight_adapters list)
+    """
+    lora_sd = {}
+    all_weight_adapters = []
+
+    for n, m in mp.model.named_modules():
+        if hasattr(m, "weight_function"):
+            if m.weight is not None:
+                adapter, params = _create_weight_adapter(
+                    m, n, existing_weights, algorithm, lora_dtype, rank
+                )
+                lora_sd.update(params)
+                key = f"{n}.weight"
+                mp.add_weight_wrapper(key, adapter)
+                all_weight_adapters.append(adapter)
+
+            if hasattr(m, "bias") and m.bias is not None:
+                bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
+                lora_sd.update(bias_params)
+                key = f"{n}.bias"
+                mp.add_weight_wrapper(key, bias_adapter)
+                all_weight_adapters.append(bias_adapter)
+
+    return lora_sd, all_weight_adapters
+
+
+def _create_optimizer(optimizer_name, parameters, learning_rate):
+    """Create optimizer based on name.
+
+    Args:
+        optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
+        parameters: Parameters to optimize
+        learning_rate: Learning rate
+
+    Returns:
+        Optimizer instance
+    """
+    if optimizer_name == "Adam":
+        return torch.optim.Adam(parameters, lr=learning_rate)
+    elif optimizer_name == "AdamW":
+        return torch.optim.AdamW(parameters, lr=learning_rate)
+    elif optimizer_name == "SGD":
+        return torch.optim.SGD(parameters, lr=learning_rate)
+    elif optimizer_name == "RMSprop":
+        return torch.optim.RMSprop(parameters, lr=learning_rate)
+
+
+def _create_loss_function(loss_function_name):
+    """Create loss function based on name.
+
+    Args:
+        loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
+
+    Returns:
+        Loss function instance
+    """
+    if loss_function_name == "MSE":
+        return torch.nn.MSELoss()
+    elif loss_function_name == "L1":
+        return torch.nn.L1Loss()
+    elif loss_function_name == "Huber":
+        return torch.nn.HuberLoss()
+    elif loss_function_name == "SmoothL1":
+        return torch.nn.SmoothL1Loss()
+
+
+def _run_training_loop(
+    guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
+):
+    """Execute the training loop.
+
+    Args:
+        guider: The guider object
+        train_sampler: The training sampler
+        latents: Latent tensors
+        num_images: Number of images
+        seed: Random seed
+        bucket_mode: Whether bucket mode is enabled
+        multi_res: Whether multi-resolution mode is enabled
+    """
+    sigmas = torch.tensor(range(num_images))
+    noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
+
+    if bucket_mode:
+        # Use first bucket's first latent as dummy for guider
+        dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
+        guider.sample(
+            noise.generate_noise({"samples": dummy_latent}),
+            dummy_latent,
+            train_sampler,
+            sigmas,
+            seed=noise.seed,
+        )
+    elif multi_res:
+        # use first latent as dummy latent if multi_res
+        latents = latents[0].repeat(num_images, 1, 1, 1)
+        guider.sample(
+            noise.generate_noise({"samples": latents}),
+            latents,
+            train_sampler,
+            sigmas,
+            seed=noise.seed,
+        )
+    else:
+        guider.sample(
+            noise.generate_noise({"samples": latents}),
+            latents,
+            train_sampler,
+            sigmas,
+            seed=noise.seed,
+        )
+
+
class TrainLoraNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
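_load_existing_lora recovers the previously trained step count from the filename convention mentioned in its comment. A quick check of that parsing, using the example filename from the comment above:

name = "trained_lora_10_steps_20250225_203716"
steps = int(name.split("_steps_")[0].split("_")[-1])
assert steps == 10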
@@ -385,6 +879,11 @@ class TrainLoraNode(io.ComfyNode):
                    default="[None]",
                    tooltip="The existing LoRA to append to. Set to None for new LoRA.",
                ),
+                io.Boolean.Input(
+                    "bucket_mode",
+                    default=False,
+                    tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
+                ),
            ],
            outputs=[
                io.Model.Output(
@@ -419,6 +918,7 @@ class TrainLoraNode(io.ComfyNode):
        algorithm,
        gradient_checkpointing,
        existing_lora,
+        bucket_mode,
    ):
        # Extract scalars from lists (due to is_input_list=True)
        model = model[0]
@ -427,215 +927,125 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
grad_accumulation_steps = grad_accumulation_steps[0]
|
grad_accumulation_steps = grad_accumulation_steps[0]
|
||||||
learning_rate = learning_rate[0]
|
learning_rate = learning_rate[0]
|
||||||
rank = rank[0]
|
rank = rank[0]
|
||||||
optimizer = optimizer[0]
|
optimizer_name = optimizer[0]
|
||||||
loss_function = loss_function[0]
|
loss_function_name = loss_function[0]
|
||||||
seed = seed[0]
|
seed = seed[0]
|
||||||
training_dtype = training_dtype[0]
|
training_dtype = training_dtype[0]
|
||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
existing_lora = existing_lora[0]
|
existing_lora = existing_lora[0]
|
||||||
|
bucket_mode = bucket_mode[0]
|
||||||
|
|
||||||
# Handle latents - either single dict or list of dicts
|
# Process latents based on mode
|
||||||
if len(latents) == 1:
|
if bucket_mode:
|
||||||
latents = latents[0]["samples"] # Single latent dict
|
latents = _process_latents_bucket_mode(latents)
|
||||||
else:
|
else:
|
||||||
latent_list = []
|
latents = _process_latents_standard_mode(latents)
|
||||||
for latent in latents:
|
|
||||||
latent = latent["samples"]
|
|
||||||
bs = latent.shape[0]
|
|
||||||
if bs != 1:
|
|
||||||
for sub_latent in latent:
|
|
||||||
latent_list.append(sub_latent[None])
|
|
||||||
else:
|
|
||||||
latent_list.append(latent)
|
|
||||||
latents = latent_list
|
|
||||||
|
|
||||||
# Handle conditioning - either single list or list of lists
|
# Process conditioning
|
||||||
if len(positive) == 1:
|
positive = _process_conditioning(positive)
|
||||||
positive = positive[0] # Single conditioning list
|
|
||||||
else:
|
|
||||||
# Multiple conditioning lists - flatten
|
|
||||||
flat_positive = []
|
|
||||||
for cond in positive:
|
|
||||||
if isinstance(cond, list):
|
|
||||||
flat_positive.extend(cond)
|
|
||||||
else:
|
|
||||||
flat_positive.append(cond)
|
|
||||||
positive = flat_positive
|
|
||||||
|
|
||||||
|
# Setup model and dtype
|
||||||
mp = model.clone()
|
mp = model.clone()
|
||||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
mp.set_model_compute_dtype(dtype)
|
mp.set_model_compute_dtype(dtype)
|
||||||
|
|
||||||
# latents here can be list of different size latent or one large batch
|
# Prepare latents and compute counts
|
||||||
if isinstance(latents, list):
|
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||||
all_shapes = set()
|
latents, dtype, bucket_mode
|
||||||
latents = [t.to(dtype) for t in latents]
|
)
|
||||||
for latent in latents:
|
|
||||||
all_shapes.add(latent.shape)
|
|
||||||
logging.info(f"Latent shapes: {all_shapes}")
|
|
||||||
if len(all_shapes) > 1:
|
|
||||||
multi_res = True
|
|
||||||
else:
|
|
||||||
multi_res = False
|
|
||||||
latents = torch.cat(latents, dim=0)
|
|
||||||
num_images = len(latents)
|
|
||||||
elif isinstance(latents, torch.Tensor):
|
|
||||||
latents = latents.to(dtype)
|
|
||||||
num_images = latents.shape[0]
|
|
||||||
else:
|
|
||||||
logging.error(f"Invalid latents type: {type(latents)}")
|
|
||||||
|
|
||||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
# Validate and expand conditioning
|
||||||
if len(positive) == 1 and num_images > 1:
|
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||||
positive = positive * num_images
|
|
||||||
elif len(positive) != num_images:
|
|
||||||
raise ValueError(
|
|
||||||
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.inference_mode(False):
|
with torch.inference_mode(False):
|
||||||
lora_sd = {}
|
# Setup models for training
|
||||||
generator = torch.Generator()
|
mp.model.requires_grad_(False)
|
||||||
generator.manual_seed(seed)
|
|
||||||
|
|
||||||
# Load existing LoRA weights if provided
|
# Load existing LoRA weights if provided
|
||||||
existing_weights = {}
|
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||||
existing_steps = 0
|
|
||||||
if existing_lora != "[None]":
|
|
||||||
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
|
||||||
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
|
||||||
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
|
||||||
if lora_path:
|
|
||||||
existing_weights = comfy.utils.load_torch_file(lora_path)
|
|
||||||
|
|
||||||
all_weight_adapters = []
|
# Setup LoRA adapters
|
||||||
for n, m in mp.model.named_modules():
|
lora_sd, all_weight_adapters = _setup_lora_adapters(
|
||||||
if hasattr(m, "weight_function"):
|
mp, existing_weights, algorithm, lora_dtype, rank
|
||||||
if m.weight is not None:
|
)
|
-    key = "{}.weight".format(n)
-    shape = m.weight.shape
-    if len(shape) >= 2:
-        alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
-        dora_scale = existing_weights.get(f"{key}.dora_scale", None)
-        for adapter_cls in adapters:
-            existing_adapter = adapter_cls.load(
-                n, existing_weights, alpha, dora_scale
-            )
-            if existing_adapter is not None:
-                break
-        else:
-            existing_adapter = None
-            adapter_cls = adapter_maps[algorithm]
-
-        if existing_adapter is not None:
-            train_adapter = existing_adapter.to_train().to(
-                lora_dtype
-            )
-        else:
-            # Use LoRA with alpha=1.0 by default
-            train_adapter = adapter_cls.create_train(
-                m.weight, rank=rank, alpha=1.0
-            ).to(lora_dtype)
-        for name, parameter in train_adapter.named_parameters():
-            lora_sd[f"{n}.{name}"] = parameter
-
-        mp.add_weight_wrapper(key, train_adapter)
-        all_weight_adapters.append(train_adapter)
-    else:
-        diff = torch.nn.Parameter(
-            torch.zeros(
-                m.weight.shape, dtype=lora_dtype, requires_grad=True
-            )
-        )
-        diff_module = BiasDiff(diff)
-        mp.add_weight_wrapper(key, BiasDiff(diff))
-        all_weight_adapters.append(diff_module)
-        lora_sd["{}.diff".format(n)] = diff
-    if hasattr(m, "bias") and m.bias is not None:
-        key = "{}.bias".format(n)
-        bias = torch.nn.Parameter(
-            torch.zeros(
-                m.bias.shape, dtype=lora_dtype, requires_grad=True
-            )
-        )
-        bias_module = BiasDiff(bias)
-        lora_sd["{}.diff_b".format(n)] = bias
-        mp.add_weight_wrapper(key, BiasDiff(bias))
-        all_weight_adapters.append(bias_module)
-
-if optimizer == "Adam":
-    optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
-elif optimizer == "AdamW":
-    optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
-elif optimizer == "SGD":
-    optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
-elif optimizer == "RMSprop":
-    optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
-
-# Setup loss function based on selection
-if loss_function == "MSE":
-    criterion = torch.nn.MSELoss()
-elif loss_function == "L1":
-    criterion = torch.nn.L1Loss()
-elif loss_function == "Huber":
-    criterion = torch.nn.HuberLoss()
-elif loss_function == "SmoothL1":
-    criterion = torch.nn.SmoothL1Loss()
-
-# setup models
+# Create optimizer and loss function
+optimizer = _create_optimizer(
+    optimizer_name, lora_sd.values(), learning_rate
+)
+criterion = _create_loss_function(loss_function_name)
+
+# Setup gradient checkpointing
if gradient_checkpointing:
    for m in find_all_highest_child_module_with_forward(
        mp.model.diffusion_model
    ):
        patch(m)
-mp.model.requires_grad_(False)
+torch.cuda.empty_cache()
+# With force_full_load=False we should be able to have offloading
+# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
    [mp], memory_required=1e20, force_full_load=True
)
+torch.cuda.empty_cache()

-# Setup sampler and guider like in test script
+# Setup loss tracking
loss_map = {"loss": []}

def loss_callback(loss):
    loss_map["loss"].append(loss)

-train_sampler = TrainSampler(
-    criterion,
-    optimizer,
-    loss_callback=loss_callback,
-    batch_size=batch_size,
-    grad_acc=grad_accumulation_steps,
-    total_steps=steps * grad_accumulation_steps,
-    seed=seed,
-    training_dtype=dtype,
-    real_dataset=latents if multi_res else None,
-)
-guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
-guider.set_conds(positive) # Set conditioning from input
+# Create sampler
+if bucket_mode:
+    train_sampler = TrainSampler(
+        criterion,
+        optimizer,
+        loss_callback=loss_callback,
+        batch_size=batch_size,
+        grad_acc=grad_accumulation_steps,
+        total_steps=steps * grad_accumulation_steps,
+        seed=seed,
+        training_dtype=dtype,
+        bucket_latents=latents,
+    )
+else:
+    train_sampler = TrainSampler(
+        criterion,
+        optimizer,
+        loss_callback=loss_callback,
+        batch_size=batch_size,
+        grad_acc=grad_accumulation_steps,
+        total_steps=steps * grad_accumulation_steps,
+        seed=seed,
+        training_dtype=dtype,
+        real_dataset=latents if multi_res else None,
+    )

-# Training loop
+# Setup guider
+guider = TrainGuider(mp)
+guider.set_conds(positive)
+
+# Run training loop
try:
-    # Generate dummy sigmas and noise
-    sigmas = torch.tensor(range(num_images))
-    noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
-    if multi_res:
-        # use first latent as dummy latent if multi_res
-        latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
-    guider.sample(
-        noise.generate_noise({"samples": latents}),
-        latents,
+    _run_training_loop(
+        guider,
        train_sampler,
-        sigmas,
-        seed=noise.seed,
+        latents,
+        num_images,
+        seed,
+        bucket_mode,
+        multi_res,
    )
finally:
    for m in mp.model.modules():
        unpatch(m)
    del train_sampler, optimizer

+    # Finalize adapters
    for adapter in all_weight_adapters:
        adapter.requires_grad_(False)
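The refactor replaces the inline optimizer and loss-function selection with _create_optimizer and _create_loss_function helpers, whose bodies are not part of this hunk. A minimal hypothetical sketch of what they presumably do, assuming they keep the same name-to-class mapping as the removed if/elif chains:

import torch

# Hypothetical sketch only; the real helpers in the new nodes_train.py may differ.
_OPTIMIZERS = {
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,
    "RMSprop": torch.optim.RMSprop,
}

_LOSS_FUNCTIONS = {
    "MSE": torch.nn.MSELoss,
    "L1": torch.nn.L1Loss,
    "Huber": torch.nn.HuberLoss,
    "SmoothL1": torch.nn.SmoothL1Loss,
}

def _create_optimizer(optimizer_name, parameters, learning_rate):
    # Map the node's optimizer dropdown value to a torch optimizer instance.
    return _OPTIMIZERS[optimizer_name](parameters, lr=learning_rate)

def _create_loss_function(loss_function_name):
    # Map the node's loss dropdown value to a torch loss module.
    return _LOSS_FUNCTIONS[loss_function_name]()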
@@ -645,7 +1055,7 @@ class TrainLoraNode(io.ComfyNode):
        return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)


class LoraModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.5.0"
+__version__ = "0.5.1"
main.py (66 changed lines)
@@ -23,6 +23,38 @@ if __name__ == "__main__":
    setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)

+if os.name == "nt":
+    os.environ['MIMALLOC_PURGE_DELAY'] = '0'
+
+if __name__ == "__main__":
+    os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
+    if args.default_device is not None:
+        default_dev = args.default_device
+        devices = list(range(32))
+        devices.remove(default_dev)
+        devices.insert(0, default_dev)
+        devices = ','.join(map(str, devices))
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
+        os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
+
+    if args.cuda_device is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
+        os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
+        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
+        logging.info("Set cuda device to: {}".format(args.cuda_device))
+
+    if args.oneapi_device_selector is not None:
+        os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
+        logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
+
+    if args.deterministic:
+        if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
+            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
+
+    import cuda_malloc
+    if "rocm" in cuda_malloc.get_torch_version_noimport():
+        os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
+

def handle_comfyui_manager_unavailable():
    if not args.windows_standalone_build:

@@ -137,40 +169,6 @@ import shutil
import threading
import gc

-if os.name == "nt":
-    os.environ['MIMALLOC_PURGE_DELAY'] = '0'
-
-if __name__ == "__main__":
-    os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
-    if args.default_device is not None:
-        default_dev = args.default_device
-        devices = list(range(32))
-        devices.remove(default_dev)
-        devices.insert(0, default_dev)
-        devices = ','.join(map(str, devices))
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
-        os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
-
-    if args.cuda_device is not None:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
-        os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
-        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
-        logging.info("Set cuda device to: {}".format(args.cuda_device))
-
-    if args.oneapi_device_selector is not None:
-        os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
-        logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
-
-    if args.deterministic:
-        if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
-            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
-
-    import cuda_malloc
-    if "rocm" in cuda_malloc.get_torch_version_noimport():
-        os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
-
-
if 'torch' in sys.modules:
    logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
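The two main.py hunks move the same device-visibility block from after the heavier imports (old lines ~137-176) to near the top of the file (new lines ~23-60), so the visibility variables are already set before torch or cuda_malloc can be imported; the retained "Torch already imported" warning guards the same ordering. The effect of --default-device is to rotate the chosen GPU to the front of the visible-device list. A minimal standalone sketch of that reordering (not ComfyUI code, values illustrative):

# Sketch of the reordering done above, e.g. when launched with --default-device 2.
default_dev = 2
devices = list(range(32))
devices.remove(default_dev)
devices.insert(0, default_dev)
print(','.join(map(str, devices)))   # "2,0,1,3,4,...,31"
# With CUDA_VISIBLE_DEVICES/HIP_VISIBLE_DEVICES set to that string,
# torch device index 0 refers to physical GPU 2.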
pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.5.0"
+version = "0.5.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
server.py (135 changed lines)
@@ -7,6 +7,7 @@ import time
import nodes
import folder_paths
import execution
+from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
import uuid
import urllib
import json

@@ -47,6 +48,12 @@ from middleware.cache_middleware import cache_control
if args.enable_manager:
    import comfyui_manager

+
+def _remove_sensitive_from_queue(queue: list) -> list:
+    """Remove sensitive data (index 5) from queue item tuples."""
+    return [item[:5] for item in queue]
+
+
async def send_socket_catch_exception(function, message):
    try:
        await function(message)

@@ -694,6 +701,129 @@ class PromptServer():
                out[node_class] = node_info(node_class)
            return web.json_response(out)

+        @routes.get("/api/jobs")
+        async def get_jobs(request):
+            """List all jobs with filtering, sorting, and pagination.
+
+            Query parameters:
+                status: Filter by status (comma-separated): pending, in_progress, completed, failed
+                workflow_id: Filter by workflow ID
+                sort_by: Sort field: created_at (default), execution_duration
+                sort_order: Sort direction: asc, desc (default)
+                limit: Max items to return (positive integer)
+                offset: Items to skip (non-negative integer, default 0)
+            """
+            query = request.rel_url.query
+
+            status_param = query.get('status')
+            workflow_id = query.get('workflow_id')
+            sort_by = query.get('sort_by', 'created_at').lower()
+            sort_order = query.get('sort_order', 'desc').lower()
+
+            status_filter = None
+            if status_param:
+                status_filter = [s.strip().lower() for s in status_param.split(',') if s.strip()]
+                invalid_statuses = [s for s in status_filter if s not in JobStatus.ALL]
+                if invalid_statuses:
+                    return web.json_response(
+                        {"error": f"Invalid status value(s): {', '.join(invalid_statuses)}. Valid values: {', '.join(JobStatus.ALL)}"},
+                        status=400
+                    )
+
+            if sort_by not in {'created_at', 'execution_duration'}:
+                return web.json_response(
+                    {"error": "sort_by must be 'created_at' or 'execution_duration'"},
+                    status=400
+                )
+
+            if sort_order not in {'asc', 'desc'}:
+                return web.json_response(
+                    {"error": "sort_order must be 'asc' or 'desc'"},
+                    status=400
+                )
+
+            limit = None
+
+            # If limit is provided, validate that it is a positive integer, else continue without a limit
+            if 'limit' in query:
+                try:
+                    limit = int(query.get('limit'))
+                    if limit <= 0:
+                        return web.json_response(
+                            {"error": "limit must be a positive integer"},
+                            status=400
+                        )
+                except (ValueError, TypeError):
+                    return web.json_response(
+                        {"error": "limit must be an integer"},
+                        status=400
+                    )
+
+            offset = 0
+            if 'offset' in query:
+                try:
+                    offset = int(query.get('offset'))
+                    if offset < 0:
+                        offset = 0
+                except (ValueError, TypeError):
+                    return web.json_response(
+                        {"error": "offset must be an integer"},
+                        status=400
+                    )
+
+            running, queued = self.prompt_queue.get_current_queue_volatile()
+            history = self.prompt_queue.get_history()
+
+            running = _remove_sensitive_from_queue(running)
+            queued = _remove_sensitive_from_queue(queued)
+
+            jobs, total = get_all_jobs(
+                running, queued, history,
+                status_filter=status_filter,
+                workflow_id=workflow_id,
+                sort_by=sort_by,
+                sort_order=sort_order,
+                limit=limit,
+                offset=offset
+            )
+
+            has_more = (offset + len(jobs)) < total
+
+            return web.json_response({
+                'jobs': jobs,
+                'pagination': {
+                    'offset': offset,
+                    'limit': limit,
+                    'total': total,
+                    'has_more': has_more
+                }
+            })
+
+        @routes.get("/api/jobs/{job_id}")
+        async def get_job_by_id(request):
+            """Get a single job by ID."""
+            job_id = request.match_info.get("job_id", None)
+            if not job_id:
+                return web.json_response(
+                    {"error": "job_id is required"},
+                    status=400
+                )
+
+            running, queued = self.prompt_queue.get_current_queue_volatile()
+            history = self.prompt_queue.get_history(prompt_id=job_id)
+
+            running = _remove_sensitive_from_queue(running)
+            queued = _remove_sensitive_from_queue(queued)
+
+            job = get_job(job_id, running, queued, history)
+            if job is None:
+                return web.json_response(
+                    {"error": "Job not found"},
+                    status=404
+                )
+
+            return web.json_response(job)
+
        @routes.get("/history")
        async def get_history(request):
            max_items = request.rel_url.query.get("max_items", None)

@@ -717,9 +847,8 @@ class PromptServer():
        async def get_queue(request):
            queue_info = {}
            current_queue = self.prompt_queue.get_current_queue_volatile()
-            remove_sensitive = lambda queue: [x[:5] for x in queue]
-            queue_info['queue_running'] = remove_sensitive(current_queue[0])
-            queue_info['queue_pending'] = remove_sensitive(current_queue[1])
+            queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0])
+            queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
            return web.json_response(queue_info)

        @routes.post("/prompt")
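For reference, a hedged sketch of how the new listing endpoint can be queried from a client, assuming a locally running server on the default 127.0.0.1:8188 address (not part of the diff):

import json
import urllib.parse
import urllib.request

# Build a query matching the parameters documented in the endpoint docstring.
params = urllib.parse.urlencode({
    "status": "completed,failed",        # comma-separated status filter
    "sort_by": "execution_duration",
    "sort_order": "desc",
    "limit": 10,
    "offset": 0,
})
with urllib.request.urlopen(f"http://127.0.0.1:8188/api/jobs?{params}") as resp:
    data = json.loads(resp.read())

print(data["pagination"])                # {'offset': 0, 'limit': 10, 'total': ..., 'has_more': ...}
for job in data["jobs"]:
    print(job["id"], job["status"])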
@@ -99,6 +99,37 @@ class ComfyClient:
        with urllib.request.urlopen(url) as response:
            return json.loads(response.read())

+    def get_jobs(self, status=None, limit=None, offset=None, sort_by=None, sort_order=None):
+        url = "http://{}/api/jobs".format(self.server_address)
+        params = {}
+        if status is not None:
+            params["status"] = status
+        if limit is not None:
+            params["limit"] = limit
+        if offset is not None:
+            params["offset"] = offset
+        if sort_by is not None:
+            params["sort_by"] = sort_by
+        if sort_order is not None:
+            params["sort_order"] = sort_order
+
+        if params:
+            url_values = urllib.parse.urlencode(params)
+            url = "{}?{}".format(url, url_values)
+
+        with urllib.request.urlopen(url) as response:
+            return json.loads(response.read())
+
+    def get_job(self, job_id):
+        url = "http://{}/api/jobs/{}".format(self.server_address, job_id)
+        try:
+            with urllib.request.urlopen(url) as response:
+                return json.loads(response.read())
+        except urllib.error.HTTPError as e:
+            if e.code == 404:
+                return None
+            raise
+
    def set_test_name(self, name):
        self.test_name = name

@@ -877,3 +908,106 @@ class TestExecution:
        result = client.get_all_history(max_items=5, offset=len(all_history) - 1)

        assert len(result) <= 1, "Should return at most 1 item when offset is near end"
+
+    # Jobs API tests
+    def test_jobs_api_job_structure(
+        self, client: ComfyClient, builder: GraphBuilder
+    ):
+        """Test that job objects have required fields"""
+        self._create_history_item(client, builder)
+
+        jobs_response = client.get_jobs(status="completed", limit=1)
+        assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
+
+        job = jobs_response["jobs"][0]
+        assert "id" in job, "Job should have id"
+        assert "status" in job, "Job should have status"
+        assert "create_time" in job, "Job should have create_time"
+        assert "outputs_count" in job, "Job should have outputs_count"
+        assert "preview_output" in job, "Job should have preview_output"
+
+    def test_jobs_api_preview_output_structure(
+        self, client: ComfyClient, builder: GraphBuilder
+    ):
+        """Test that preview_output has correct structure"""
+        self._create_history_item(client, builder)
+
+        jobs_response = client.get_jobs(status="completed", limit=1)
+        job = jobs_response["jobs"][0]
+
+        if job["preview_output"] is not None:
+            preview = job["preview_output"]
+            assert "filename" in preview, "Preview should have filename"
+            assert "nodeId" in preview, "Preview should have nodeId"
+            assert "mediaType" in preview, "Preview should have mediaType"
+
+    def test_jobs_api_pagination(
+        self, client: ComfyClient, builder: GraphBuilder
+    ):
+        """Test jobs API pagination"""
+        for _ in range(5):
+            self._create_history_item(client, builder)
+
+        first_page = client.get_jobs(limit=2, offset=0)
+        second_page = client.get_jobs(limit=2, offset=2)
+
+        assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
+        assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
+
+        first_ids = {j["id"] for j in first_page["jobs"]}
+        second_ids = {j["id"] for j in second_page["jobs"]}
+        assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
+
+    def test_jobs_api_sorting(
+        self, client: ComfyClient, builder: GraphBuilder
+    ):
+        """Test jobs API sorting"""
+        for _ in range(3):
+            self._create_history_item(client, builder)
+
+        desc_jobs = client.get_jobs(sort_order="desc")
+        asc_jobs = client.get_jobs(sort_order="asc")
+
+        if len(desc_jobs["jobs"]) >= 2:
+            desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
+            asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
+            if len(desc_times) >= 2:
+                assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
+            if len(asc_times) >= 2:
+                assert asc_times == sorted(asc_times), "Asc should be oldest first"
+
+    def test_jobs_api_status_filter(
+        self, client: ComfyClient, builder: GraphBuilder
+    ):
+        """Test jobs API status filtering"""
+        self._create_history_item(client, builder)
+
+        completed_jobs = client.get_jobs(status="completed")
+        assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
+
+        for job in completed_jobs["jobs"]:
+            assert job["status"] == "completed", "Should only return completed jobs"
+
+        # Pending jobs are transient - just verify filter doesn't error
+        pending_jobs = client.get_jobs(status="pending")
+        for job in pending_jobs["jobs"]:
+            assert job["status"] == "pending", "Should only return pending jobs"
+
+    def test_get_job_by_id(
+        self, client: ComfyClient, builder: GraphBuilder
+    ):
+        """Test getting a single job by ID"""
+        result = self._create_history_item(client, builder)
+        prompt_id = result.get_prompt_id()
+
+        job = client.get_job(prompt_id)
+        assert job is not None, "Should find the job"
+        assert job["id"] == prompt_id, "Job ID should match"
+        assert "outputs" in job, "Single job should include outputs"
+
+    def test_get_job_not_found(
+        self, client: ComfyClient, builder: GraphBuilder
+    ):
+        """Test getting a non-existent job returns 404"""
+        job = client.get_job("nonexistent-job-id")
+        assert job is None, "Non-existent job should return None"
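A short hypothetical usage of the new test-client helpers, given an already-connected ComfyClient instance from the existing test fixtures (only methods defined in the diff are used):

page = client.get_jobs(status="completed", sort_by="created_at", limit=5)
print(page["pagination"]["total"], len(page["jobs"]))

# get_job returns None when the server responds with HTTP 404.
job = client.get_job(page["jobs"][0]["id"]) if page["jobs"] else None
if job is not None:
    print(job["status"], len(job.get("outputs", {})))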
tests/execution/test_jobs.py (new file, 361 lines)
@@ -0,0 +1,361 @@
"""Unit tests for comfy_execution/jobs.py"""

from comfy_execution.jobs import (
    JobStatus,
    is_previewable,
    normalize_queue_item,
    normalize_history_item,
    get_outputs_summary,
    apply_sorting,
)


class TestJobStatus:
    """Test JobStatus constants."""

    def test_status_values(self):
        """Status constants should have expected string values."""
        assert JobStatus.PENDING == 'pending'
        assert JobStatus.IN_PROGRESS == 'in_progress'
        assert JobStatus.COMPLETED == 'completed'
        assert JobStatus.FAILED == 'failed'

    def test_all_contains_all_statuses(self):
        """ALL should contain all status values."""
        assert JobStatus.PENDING in JobStatus.ALL
        assert JobStatus.IN_PROGRESS in JobStatus.ALL
        assert JobStatus.COMPLETED in JobStatus.ALL
        assert JobStatus.FAILED in JobStatus.ALL
        assert len(JobStatus.ALL) == 4


class TestIsPreviewable:
    """Unit tests for is_previewable()"""

    def test_previewable_media_types(self):
        """Images, video, audio media types should be previewable."""
        for media_type in ['images', 'video', 'audio']:
            assert is_previewable(media_type, {}) is True

    def test_non_previewable_media_types(self):
        """Other media types should not be previewable."""
        for media_type in ['latents', 'text', 'metadata', 'files']:
            assert is_previewable(media_type, {}) is False

    def test_3d_extensions_previewable(self):
        """3D file extensions should be previewable regardless of media_type."""
        for ext in ['.obj', '.fbx', '.gltf', '.glb']:
            item = {'filename': f'model{ext}'}
            assert is_previewable('files', item) is True

    def test_3d_extensions_case_insensitive(self):
        """3D extension check should be case insensitive."""
        item = {'filename': 'MODEL.GLB'}
        assert is_previewable('files', item) is True

    def test_video_format_previewable(self):
        """Items with video/ format should be previewable."""
        item = {'format': 'video/mp4'}
        assert is_previewable('files', item) is True

    def test_audio_format_previewable(self):
        """Items with audio/ format should be previewable."""
        item = {'format': 'audio/wav'}
        assert is_previewable('files', item) is True

    def test_other_format_not_previewable(self):
        """Items with other format should not be previewable."""
        item = {'format': 'application/json'}
        assert is_previewable('files', item) is False


class TestGetOutputsSummary:
    """Unit tests for get_outputs_summary()"""

    def test_empty_outputs(self):
        """Empty outputs should return 0 count and None preview."""
        count, preview = get_outputs_summary({})
        assert count == 0
        assert preview is None

    def test_counts_across_multiple_nodes(self):
        """Outputs from multiple nodes should all be counted."""
        outputs = {
            'node1': {'images': [{'filename': 'a.png', 'type': 'output'}]},
            'node2': {'images': [{'filename': 'b.png', 'type': 'output'}]},
            'node3': {'images': [
                {'filename': 'c.png', 'type': 'output'},
                {'filename': 'd.png', 'type': 'output'}
            ]}
        }
        count, preview = get_outputs_summary(outputs)
        assert count == 4

    def test_skips_animated_key_and_non_list_values(self):
        """The 'animated' key and non-list values should be skipped."""
        outputs = {
            'node1': {
                'images': [{'filename': 'test.png', 'type': 'output'}],
                'animated': [True],  # Should skip due to key name
                'metadata': 'string',  # Should skip due to non-list
                'count': 42  # Should skip due to non-list
            }
        }
        count, preview = get_outputs_summary(outputs)
        assert count == 1

    def test_preview_prefers_type_output(self):
        """Items with type='output' should be preferred for preview."""
        outputs = {
            'node1': {
                'images': [
                    {'filename': 'temp.png', 'type': 'temp'},
                    {'filename': 'output.png', 'type': 'output'}
                ]
            }
        }
        count, preview = get_outputs_summary(outputs)
        assert count == 2
        assert preview['filename'] == 'output.png'

    def test_preview_fallback_when_no_output_type(self):
        """If no type='output', should use first previewable."""
        outputs = {
            'node1': {
                'images': [
                    {'filename': 'temp1.png', 'type': 'temp'},
                    {'filename': 'temp2.png', 'type': 'temp'}
                ]
            }
        }
        count, preview = get_outputs_summary(outputs)
        assert preview['filename'] == 'temp1.png'

    def test_non_previewable_media_types_counted_but_no_preview(self):
        """Non-previewable media types should be counted but not used as preview."""
        outputs = {
            'node1': {
                'latents': [
                    {'filename': 'latent1.safetensors'},
                    {'filename': 'latent2.safetensors'}
                ]
            }
        }
        count, preview = get_outputs_summary(outputs)
        assert count == 2
        assert preview is None

    def test_previewable_media_types(self):
        """Images, video, and audio media types should be previewable."""
        for media_type in ['images', 'video', 'audio']:
            outputs = {
                'node1': {
                    media_type: [{'filename': 'test.file', 'type': 'output'}]
                }
            }
            count, preview = get_outputs_summary(outputs)
            assert preview is not None, f"{media_type} should be previewable"

    def test_3d_files_previewable(self):
        """3D file extensions should be previewable."""
        for ext in ['.obj', '.fbx', '.gltf', '.glb']:
            outputs = {
                'node1': {
                    'files': [{'filename': f'model{ext}', 'type': 'output'}]
                }
            }
            count, preview = get_outputs_summary(outputs)
            assert preview is not None, f"3D file {ext} should be previewable"

    def test_format_mime_type_previewable(self):
        """Files with video/ or audio/ format should be previewable."""
        for fmt in ['video/x-custom', 'audio/x-custom']:
            outputs = {
                'node1': {
                    'files': [{'filename': 'file.custom', 'format': fmt, 'type': 'output'}]
                }
            }
            count, preview = get_outputs_summary(outputs)
            assert preview is not None, f"Format {fmt} should be previewable"

    def test_preview_enriched_with_node_metadata(self):
        """Preview should include nodeId, mediaType, and original fields."""
        outputs = {
            'node123': {
                'images': [{'filename': 'test.png', 'type': 'output', 'subfolder': 'outputs'}]
            }
        }
        count, preview = get_outputs_summary(outputs)
        assert preview['nodeId'] == 'node123'
        assert preview['mediaType'] == 'images'
        assert preview['subfolder'] == 'outputs'


class TestApplySorting:
    """Unit tests for apply_sorting()"""

    def test_sort_by_create_time_desc(self):
        """Default sort by create_time descending."""
        jobs = [
            {'id': 'a', 'create_time': 100},
            {'id': 'b', 'create_time': 300},
            {'id': 'c', 'create_time': 200},
        ]
        result = apply_sorting(jobs, 'created_at', 'desc')
        assert [j['id'] for j in result] == ['b', 'c', 'a']

    def test_sort_by_create_time_asc(self):
        """Sort by create_time ascending."""
        jobs = [
            {'id': 'a', 'create_time': 100},
            {'id': 'b', 'create_time': 300},
            {'id': 'c', 'create_time': 200},
        ]
        result = apply_sorting(jobs, 'created_at', 'asc')
        assert [j['id'] for j in result] == ['a', 'c', 'b']

    def test_sort_by_execution_duration(self):
        """Sort by execution_duration should order by duration."""
        jobs = [
            {'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100},  # 5s
            {'id': 'b', 'create_time': 300, 'execution_start_time': 300, 'execution_end_time': 1300},  # 1s
            {'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200},  # 3s
        ]
        result = apply_sorting(jobs, 'execution_duration', 'desc')
        assert [j['id'] for j in result] == ['a', 'c', 'b']

    def test_sort_with_none_values(self):
        """Jobs with None values should sort as 0."""
        jobs = [
            {'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100},
            {'id': 'b', 'create_time': 300, 'execution_start_time': None, 'execution_end_time': None},
            {'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200},
        ]
        result = apply_sorting(jobs, 'execution_duration', 'asc')
        assert result[0]['id'] == 'b'  # None treated as 0, comes first


class TestNormalizeQueueItem:
    """Unit tests for normalize_queue_item()"""

    def test_basic_normalization(self):
        """Queue item should be normalized to job dict."""
        item = (
            10,  # priority/number
            'prompt-123',  # prompt_id
            {'nodes': {}},  # prompt
            {
                'create_time': 1234567890,
                'extra_pnginfo': {'workflow': {'id': 'workflow-abc'}}
            },  # extra_data
            ['node1'],  # outputs_to_execute
        )
        job = normalize_queue_item(item, JobStatus.PENDING)

        assert job['id'] == 'prompt-123'
        assert job['status'] == 'pending'
        assert job['priority'] == 10
        assert job['create_time'] == 1234567890
        assert 'execution_start_time' not in job
        assert 'execution_end_time' not in job
        assert 'execution_error' not in job
        assert 'preview_output' not in job
        assert job['outputs_count'] == 0
        assert job['workflow_id'] == 'workflow-abc'


class TestNormalizeHistoryItem:
    """Unit tests for normalize_history_item()"""

    def test_completed_job(self):
        """Completed history item should have correct status and times from messages."""
        history_item = {
            'prompt': (
                5,  # priority
                'prompt-456',
                {'nodes': {}},
                {
                    'create_time': 1234567890000,
                    'extra_pnginfo': {'workflow': {'id': 'workflow-xyz'}}
                },
                ['node1'],
            ),
            'status': {
                'status_str': 'success',
                'completed': True,
                'messages': [
                    ('execution_start', {'prompt_id': 'prompt-456', 'timestamp': 1234567890500}),
                    ('execution_success', {'prompt_id': 'prompt-456', 'timestamp': 1234567893000}),
                ]
            },
            'outputs': {},
        }
        job = normalize_history_item('prompt-456', history_item)

        assert job['id'] == 'prompt-456'
        assert job['status'] == 'completed'
        assert job['priority'] == 5
        assert job['execution_start_time'] == 1234567890500
        assert job['execution_end_time'] == 1234567893000
        assert job['workflow_id'] == 'workflow-xyz'

    def test_failed_job(self):
        """Failed history item should have failed status and error from messages."""
        history_item = {
            'prompt': (
                5,
                'prompt-789',
                {'nodes': {}},
                {'create_time': 1234567890000},
                ['node1'],
            ),
            'status': {
                'status_str': 'error',
                'completed': False,
                'messages': [
                    ('execution_start', {'prompt_id': 'prompt-789', 'timestamp': 1234567890500}),
                    ('execution_error', {
                        'prompt_id': 'prompt-789',
                        'node_id': '5',
                        'node_type': 'KSampler',
                        'exception_message': 'CUDA out of memory',
                        'exception_type': 'RuntimeError',
                        'traceback': ['Traceback...', 'RuntimeError: CUDA out of memory'],
                        'timestamp': 1234567891000,
                    })
                ]
            },
            'outputs': {},
        }

        job = normalize_history_item('prompt-789', history_item)
        assert job['status'] == 'failed'
        assert job['execution_start_time'] == 1234567890500
        assert job['execution_end_time'] == 1234567891000
        assert job['execution_error']['node_id'] == '5'
        assert job['execution_error']['node_type'] == 'KSampler'
        assert job['execution_error']['exception_message'] == 'CUDA out of memory'

    def test_include_outputs(self):
        """When include_outputs=True, should include full output data."""
        history_item = {
            'prompt': (
                5,
                'prompt-123',
                {'nodes': {'1': {}}},
                {'create_time': 1234567890, 'client_id': 'abc'},
                ['node1'],
            ),
            'status': {'status_str': 'success', 'completed': True, 'messages': []},
            'outputs': {'node1': {'images': [{'filename': 'test.png'}]}},
        }
        job = normalize_history_item('prompt-123', history_item, include_outputs=True)

        assert 'outputs' in job
        assert 'workflow' in job
        assert 'execution_status' in job
        assert job['outputs'] == {'node1': {'images': [{'filename': 'test.png'}]}}
        assert job['workflow'] == {
            'prompt': {'nodes': {'1': {}}},
            'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
        }
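The TestApplySorting cases pin down the sorting contract without showing the implementation. A minimal sketch that satisfies those tests, assuming the real apply_sorting in comfy_execution/jobs.py may differ in detail:

def apply_sorting(jobs, sort_by, sort_order):
    # 'desc' means newest first / longest first; anything else sorts ascending.
    reverse = (sort_order == 'desc')
    if sort_by == 'execution_duration':
        def key(job):
            # Missing or None timestamps are treated as 0, so unexecuted jobs
            # get duration 0 and sort first in ascending order.
            start = job.get('execution_start_time') or 0
            end = job.get('execution_end_time') or 0
            return end - start
    else:  # 'created_at'
        def key(job):
            return job.get('create_time') or 0
    return sorted(jobs, key=key, reverse=reverse)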