Mirror of https://github.com/comfyanonymous/ComfyUI.git

Merge branch 'master' into dr-support-pip-cm

Commit fc5703c468
@@ -431,7 +431,7 @@ async def upload_video_to_comfyapi(
                 f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
             )
     except Exception as e:
-        logging.error(f"Error getting video duration: {e}")
+        logging.error("Error getting video duration: %s", str(e))
         raise ValueError(f"Could not verify video duration from source: {e}") from e

    upload_mime_type = f"video/{container.value.lower()}"
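This first hunk sets the pattern repeated throughout the commit: f-string log calls become lazy %-style calls, so the message is only interpolated when a handler actually emits the record. A minimal sketch of the difference with the stdlib logging module (the variable is illustrative, not from the commit):

```python
import logging

logging.basicConfig(level=logging.INFO)
duration = 12.3456

# Eager: the f-string is built even though DEBUG records are discarded.
logging.debug(f"Video duration: {duration:.2f}s")

# Lazy: the format string and arguments are stored on the LogRecord and
# only merged into a message if the DEBUG level is enabled.
logging.debug("Video duration: %.2fs", duration)
```

Note that only the interpolation is deferred: an argument such as str(e) is still evaluated at call time, which is why the converted calls below pass it explicitly.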
@@ -359,10 +359,10 @@ class ApiClient:
         if params:
             params = {k: v for k, v in params.items() if v is not None}  # aiohttp fails to serialize None values

-        logging.debug(f"[DEBUG] Request Headers: {request_headers}")
-        logging.debug(f"[DEBUG] Files: {files}")
-        logging.debug(f"[DEBUG] Params: {params}")
-        logging.debug(f"[DEBUG] Data: {data}")
+        logging.debug("[DEBUG] Request Headers: %s", request_headers)
+        logging.debug("[DEBUG] Files: %s", files)
+        logging.debug("[DEBUG] Params: %s", params)
+        logging.debug("[DEBUG] Data: %s", data)

         if content_type == "application/x-www-form-urlencoded":
             payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
@@ -592,9 +592,9 @@ class ApiClient:
                     error_message=f"HTTP Error {exc.status}",
                 )

-            logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
+            logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
             if response_content:
-                logging.debug(f"[DEBUG] Response content: {response_content}")
+                logging.debug("[DEBUG] Response content: %s", response_content)

             # Retry if eligible
             if status_code in self.retry_status_codes and retry_count < self.max_retries:
@@ -738,11 +738,9 @@ class SynchronousOperation(Generic[T, R]):
                 if isinstance(v, Enum):
                     request_dict[k] = v.value

-            logging.debug(
-                f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
-            )
-            logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
-            logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
+            logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
+            logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
+            logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)

             response_json = await client.request(
                 self.endpoint.method.value,
@@ -757,11 +755,11 @@ class SynchronousOperation(Generic[T, R]):
             logging.debug("=" * 50)
             logging.debug("[DEBUG] RESPONSE DETAILS:")
             logging.debug("[DEBUG] Status Code: 200 (Success)")
-            logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
+            logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
             logging.debug("=" * 50)

             parsed_response = self.endpoint.response_model.model_validate(response_json)
-            logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
+            logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
             return parsed_response
         finally:
             if owns_client:
@@ -877,7 +875,7 @@ class PollingOperation(Generic[T, R]):
         status = TaskStatus.PENDING
         for poll_count in range(1, self.max_poll_attempts + 1):
             try:
-                logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
+                logging.debug("[DEBUG] Polling attempt #%s", poll_count)

                 request_dict = (
                     None if self.request is None else self.request.model_dump(exclude_none=True)
@@ -885,10 +883,13 @@

                 if poll_count == 1:
                     logging.debug(
-                        f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
+                        "[DEBUG] Poll Request: %s %s",
+                        self.poll_endpoint.method.value,
+                        self.poll_endpoint.path,
                     )
                     logging.debug(
-                        f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
+                        "[DEBUG] Poll Request Data: %s",
+                        json.dumps(request_dict, indent=2) if request_dict else "None",
                     )

                 # Query task status
@@ -903,7 +904,7 @@ class PollingOperation(Generic[T, R]):

                 # Check if task is complete
                 status = self._check_task_status(response_obj)
-                logging.debug(f"[DEBUG] Task Status: {status}")
+                logging.debug("[DEBUG] Task Status: %s", status)

                 # If progress extractor is provided, extract progress
                 if self.progress_extractor:
@@ -917,7 +918,7 @@ class PollingOperation(Generic[T, R]):
                     result_url = self.result_url_extractor(response_obj)
                     if result_url:
                         message = f"Result URL: {result_url}"
-                        logging.debug(f"[DEBUG] {message}")
+                        logging.debug("[DEBUG] %s", message)
                         self._display_text_on_node(message)
                     self.final_response = response_obj
                     if self.progress_extractor:
@@ -925,7 +926,7 @@ class PollingOperation(Generic[T, R]):
                     return self.final_response
                 if status == TaskStatus.FAILED:
                     message = f"Task failed: {json.dumps(resp)}"
-                    logging.error(f"[DEBUG] {message}")
+                    logging.error("[DEBUG] %s", message)
                     raise Exception(message)
                 logging.debug("[DEBUG] Task still pending, continuing to poll...")
                 # Task pending – wait
@@ -939,7 +940,12 @@ class PollingOperation(Generic[T, R]):
                     raise Exception(
                         f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
                     ) from e
-                logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
+                logging.warning(
+                    "Network error (%s/%s): %s",
+                    consecutive_errors,
+                    max_consecutive_errors,
+                    str(e),
+                )
                 await asyncio.sleep(self.poll_interval)
             except Exception as e:
                 # For other errors, increment count and potentially abort
@@ -949,10 +955,13 @@ class PollingOperation(Generic[T, R]):
                         f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
                     ) from e

-                logging.error(f"[DEBUG] Polling error: {str(e)}")
+                logging.error("[DEBUG] Polling error: %s", str(e))
                 logging.warning(
-                    f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
-                    f"Will retry in {self.poll_interval} seconds."
+                    "Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
+                    poll_count,
+                    self.max_poll_attempts,
+                    str(e),
+                    self.poll_interval,
                 )
                 await asyncio.sleep(self.poll_interval)

@@ -21,7 +21,7 @@ def get_log_directory():
     try:
         os.makedirs(log_dir, exist_ok=True)
     except Exception as e:
-        logger.error(f"Error creating API log directory {log_dir}: {e}")
+        logger.error("Error creating API log directory %s: %s", log_dir, str(e))
         # Fallback to base temp directory if sub-directory creation fails
         return base_temp_dir
     return log_dir
@@ -122,9 +122,9 @@ def log_request_response(
     try:
         with open(filepath, "w", encoding="utf-8") as f:
             f.write("\n".join(log_content))
-        logger.debug(f"API log saved to: {filepath}")
+        logger.debug("API log saved to: %s", filepath)
     except Exception as e:
-        logger.error(f"Error writing API log to {filepath}: {e}")
+        logger.error("Error writing API log to %s: %s", filepath, str(e))


 if __name__ == '__main__':
@@ -296,7 +296,7 @@ def validate_video_result_response(response) -> None:
     """Validates that the Kling task result contains a video."""
     if not is_valid_video_response(response):
         error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response."
-        logging.error(f"Error: {error_msg}.\nResponse: {response}")
+        logging.error("Error: %s.\nResponse: %s", error_msg, response)
         raise Exception(error_msg)

@@ -304,7 +304,7 @@ def validate_image_result_response(response) -> None:
     """Validates that the Kling task result contains an image."""
     if not is_valid_image_response(response):
         error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response."
-        logging.error(f"Error: {error_msg}.\nResponse: {response}")
+        logging.error("Error: %s.\nResponse: %s", error_msg, response)
         raise Exception(error_msg)

@@ -500,7 +500,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
             raise Exception(
                 f"No video was found in the response. Full response: {file_result.model_dump()}"
             )
-        logging.info(f"Generated video URL: {file_url}")
+        logging.info("Generated video URL: %s", file_url)
         if cls.hidden.unique_id:
             if hasattr(file_result.file, "backup_download_url"):
                 message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
@@ -237,7 +237,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
         audio_stream = None

         for stream in input_container.streams:
-            logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
+            logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
             if isinstance(stream, av.VideoStream):
                 # Create output video stream with same parameters
                 video_stream = output_container.add_stream(
@@ -247,7 +247,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
                 video_stream.height = stream.height
                 video_stream.pix_fmt = "yuv420p"
                 logging.info(
-                    f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
+                    "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
                 )
             elif isinstance(stream, av.AudioStream):
                 # Create output audio stream with same parameters
@@ -256,9 +256,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
                 )
                 audio_stream.sample_rate = stream.sample_rate
                 audio_stream.layout = stream.layout
-                logging.info(
-                    f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
-                )
+                logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)

         # Calculate target frame count that's divisible by 16
         fps = input_container.streams.video[0].average_rate
@@ -288,9 +286,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
         for packet in video_stream.encode():
             output_container.mux(packet)

-        logging.info(
-            f"Encoded {frame_count} video frames (target: {target_frames})"
-        )
+        logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)

         # Decode and re-encode audio frames
         if audio_stream:
@@ -308,7 +304,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
             for packet in audio_stream.encode():
                 output_container.mux(packet)

-            logging.info(f"Encoded {audio_frame_count} audio frames")
+            logging.info("Encoded %s audio frames", audio_frame_count)

         # Close containers
         output_container.close()
@@ -172,16 +172,16 @@ async def create_generate_task(
     logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
     subscription_key = response.jobs.subscription_key
     task_uuid = response.uuid
-    logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
+    logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid)
     return task_uuid, subscription_key


 def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
     all_done = all(job.status == JobStatus.Done for job in response.jobs)
     status_list = [str(job.status) for job in response.jobs]
-    logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
+    logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list)
     if any(job.status == JobStatus.Failed for job in response.jobs):
-        logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
+        logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list)
         raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
     if all_done:
         return "DONE"
@@ -235,7 +235,7 @@ async def download_files(url_list, task_uuid):
         file_path = os.path.join(save_path, file_name)
         if file_path.endswith(".glb"):
             model_file_path = file_path
-        logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
+        logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
         max_retries = 5
         for attempt in range(max_retries):
             try:
@@ -246,7 +246,7 @@ async def download_files(url_list, task_uuid):
                         f.write(chunk)
                 break
             except Exception as e:
-                logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
+                logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
                 if attempt < max_retries - 1:
                     logging.info("Retrying...")
                     await asyncio.sleep(2)
@@ -215,7 +215,7 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
         initial_response = await initial_operation.execute()
         operation_name = initial_response.name

-        logging.info(f"Veo generation started with operation name: {operation_name}")
+        logging.info("Veo generation started with operation name: %s", operation_name)

         # Define status extractor function
         def status_extractor(response):
@@ -1,6 +1,9 @@
 import torch
 import comfy.utils
 from enum import Enum
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io


 def resize_mask(mask, shape):
     return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
@@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
     return out_image, out_alpha


-class PorterDuffImageComposite:
+class PorterDuffImageComposite(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "source": ("IMAGE",),
-                "source_alpha": ("MASK",),
-                "destination": ("IMAGE",),
-                "destination_alpha": ("MASK",),
-                "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
-            },
-        }
-
-    RETURN_TYPES = ("IMAGE", "MASK")
-    FUNCTION = "composite"
-    CATEGORY = "mask/compositing"
-
-    def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="PorterDuffImageComposite",
+            display_name="Porter-Duff Image Composite",
+            category="mask/compositing",
+            inputs=[
+                io.Image.Input("source"),
+                io.Mask.Input("source_alpha"),
+                io.Image.Input("destination"),
+                io.Mask.Input("destination_alpha"),
+                io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
+            ],
+            outputs=[
+                io.Image.Output(),
+                io.Mask.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
         batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
         out_images = []
         out_alphas = []
@@ -150,45 +157,48 @@ class PorterDuffImageComposite:
             out_images.append(out_image)
             out_alphas.append(out_alpha.squeeze(2))

-        result = (torch.stack(out_images), torch.stack(out_alphas))
-        return result
+        return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))


-class SplitImageWithAlpha:
+class SplitImageWithAlpha(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "image": ("IMAGE",),
-            }
-        }
-
-    CATEGORY = "mask/compositing"
-    RETURN_TYPES = ("IMAGE", "MASK")
-    FUNCTION = "split_image_with_alpha"
-
-    def split_image_with_alpha(self, image: torch.Tensor):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SplitImageWithAlpha",
+            display_name="Split Image with Alpha",
+            category="mask/compositing",
+            inputs=[
+                io.Image.Input("image"),
+            ],
+            outputs=[
+                io.Image.Output(),
+                io.Mask.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image: torch.Tensor) -> io.NodeOutput:
         out_images = [i[:,:,:3] for i in image]
         out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
-        result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
-        return result
+        return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))


-class JoinImageWithAlpha:
+class JoinImageWithAlpha(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "image": ("IMAGE",),
-                "alpha": ("MASK",),
-            }
-        }
-
-    CATEGORY = "mask/compositing"
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "join_image_with_alpha"
-
-    def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="JoinImageWithAlpha",
+            display_name="Join Image with Alpha",
+            category="mask/compositing",
+            inputs=[
+                io.Image.Input("image"),
+                io.Mask.Input("alpha"),
+            ],
+            outputs=[io.Image.Output()],
+        )
+
+    @classmethod
+    def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
         batch_size = min(len(image), len(alpha))
         out_images = []

@@ -196,19 +206,18 @@ class JoinImageWithAlpha:
         for i in range(batch_size):
             out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))

-        result = (torch.stack(out_images),)
-        return result
+        return io.NodeOutput(torch.stack(out_images))


-NODE_CLASS_MAPPINGS = {
-    "PorterDuffImageComposite": PorterDuffImageComposite,
-    "SplitImageWithAlpha": SplitImageWithAlpha,
-    "JoinImageWithAlpha": JoinImageWithAlpha,
-}
-
-
-NODE_DISPLAY_NAME_MAPPINGS = {
-    "PorterDuffImageComposite": "Porter-Duff Image Composite",
-    "SplitImageWithAlpha": "Split Image with Alpha",
-    "JoinImageWithAlpha": "Join Image with Alpha",
-}
+class CompositingExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            PorterDuffImageComposite,
+            SplitImageWithAlpha,
+            JoinImageWithAlpha,
+        ]
+
+
+async def comfy_entrypoint() -> CompositingExtension:
+    return CompositingExtension()
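Taken together, the compositing hunks show the full shape of the schema-based node API: define_schema() replaces the INPUT_TYPES/RETURN_TYPES/FUNCTION/CATEGORY class attributes, execute() is a classmethod, results are wrapped in io.NodeOutput instead of bare tuples, and registration moves from the NODE_CLASS_MAPPINGS/NODE_DISPLAY_NAME_MAPPINGS dicts to a ComfyExtension subclass plus a comfy_entrypoint() coroutine. A hypothetical single-node extension in the same style (ExampleInvert and ExampleExtension are invented names for illustration, not part of this commit):

```python
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class ExampleInvert(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ExampleInvert",  # hypothetical node id
            display_name="Example Invert",
            category="image/example",
            inputs=[io.Image.Input("image")],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> io.NodeOutput:
        # IMAGE tensors are floats in [0, 1], so inversion is a subtraction.
        return io.NodeOutput(1.0 - image)


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [ExampleInvert]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()
```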
@@ -2,6 +2,8 @@ import comfy.utils
 import comfy_extras.nodes_post_processing
 import torch
 import nodes
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io


 def reshape_latent_to(target_shape, latent, repeat_batch=True):
@@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True):
     return latent


-class LatentAdd:
+class LatentAdd(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced"
-
-    def op(self, samples1, samples2):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentAdd",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples1, samples2) -> io.NodeOutput:
         samples_out = samples1.copy()

         s1 = samples1["samples"]
@@ -31,19 +39,25 @@ class LatentAdd:

         s2 = reshape_latent_to(s1.shape, s2)
         samples_out["samples"] = s1 + s2
-        return (samples_out,)
+        return io.NodeOutput(samples_out)

-class LatentSubtract:
+class LatentSubtract(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced"
-
-    def op(self, samples1, samples2):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentSubtract",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples1, samples2) -> io.NodeOutput:
         samples_out = samples1.copy()

         s1 = samples1["samples"]
@@ -51,41 +65,49 @@ class LatentSubtract:

         s2 = reshape_latent_to(s1.shape, s2)
         samples_out["samples"] = s1 - s2
-        return (samples_out,)
+        return io.NodeOutput(samples_out)

-class LatentMultiply:
+class LatentMultiply(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT",),
-                              "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
-                              }}
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced"
-
-    def op(self, samples, multiplier):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentMultiply",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples"),
+                io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, multiplier) -> io.NodeOutput:
         samples_out = samples.copy()

         s1 = samples["samples"]
         samples_out["samples"] = s1 * multiplier
-        return (samples_out,)
+        return io.NodeOutput(samples_out)

-class LatentInterpolate:
+class LatentInterpolate(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples1": ("LATENT",),
-                              "samples2": ("LATENT",),
-                              "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
-                              }}
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced"
-
-    def op(self, samples1, samples2, ratio):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentInterpolate",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
+                io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
         samples_out = samples1.copy()

         s1 = samples1["samples"]
@@ -104,19 +126,26 @@ class LatentInterpolate:
         st = torch.nan_to_num(t / mt)

         samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
-        return (samples_out,)
+        return io.NodeOutput(samples_out)

-class LatentConcat:
+class LatentConcat(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced"
-
-    def op(self, samples1, samples2, dim):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentConcat",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
+                io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
         samples_out = samples1.copy()

         s1 = samples1["samples"]
@@ -136,22 +165,27 @@ class LatentConcat:
             dim = -3

         samples_out["samples"] = torch.cat(c, dim=dim)
-        return (samples_out,)
+        return io.NodeOutput(samples_out)

-class LatentCut:
+class LatentCut(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"samples": ("LATENT",),
-                             "dim": (["x", "y", "t"], ),
-                             "index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
-                             "amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced"
-
-    def op(self, samples, dim, index, amount):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentCut",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples"),
+                io.Combo.Input("dim", options=["x", "y", "t"]),
+                io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
+                io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
         samples_out = samples.copy()

         s1 = samples["samples"]
@@ -171,19 +205,25 @@ class LatentCut:
             amount = min(-index, amount)

         samples_out["samples"] = torch.narrow(s1, dim, index, amount)
-        return (samples_out,)
+        return io.NodeOutput(samples_out)

-class LatentBatch:
+class LatentBatch(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "batch"
-
-    CATEGORY = "latent/batch"
-
-    def batch(self, samples1, samples2):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentBatch",
+            category="latent/batch",
+            inputs=[
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples1, samples2) -> io.NodeOutput:
         samples_out = samples1.copy()
         s1 = samples1["samples"]
         s2 = samples2["samples"]
@@ -192,20 +232,25 @@ class LatentBatch:
         s = torch.cat((s1, s2), dim=0)
         samples_out["samples"] = s
         samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
-        return (samples_out,)
+        return io.NodeOutput(samples_out)

-class LatentBatchSeedBehavior:
+class LatentBatchSeedBehavior(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT",),
-                              "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced"
-
-    def op(self, samples, seed_behavior):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentBatchSeedBehavior",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples"),
+                io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, seed_behavior) -> io.NodeOutput:
         samples_out = samples.copy()
         latent = samples["samples"]
         if seed_behavior == "random":
@@ -215,41 +260,50 @@ class LatentBatchSeedBehavior:
             batch_number = samples_out.get("batch_index", [0])[0]
             samples_out["batch_index"] = [batch_number] * latent.shape[0]

-        return (samples_out,)
+        return io.NodeOutput(samples_out)

-class LatentApplyOperation:
+class LatentApplyOperation(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT",),
-                              "operation": ("LATENT_OPERATION",),
-                              }}
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced/operations"
-    EXPERIMENTAL = True
-
-    def op(self, samples, operation):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentApplyOperation",
+            category="latent/advanced/operations",
+            is_experimental=True,
+            inputs=[
+                io.Latent.Input("samples"),
+                io.LatentOperation.Input("operation"),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, operation) -> io.NodeOutput:
         samples_out = samples.copy()

         s1 = samples["samples"]
         samples_out["samples"] = operation(latent=s1)
-        return (samples_out,)
+        return io.NodeOutput(samples_out)

-class LatentApplyOperationCFG:
+class LatentApplyOperationCFG(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "operation": ("LATENT_OPERATION",),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
-
-    CATEGORY = "latent/advanced/operations"
-    EXPERIMENTAL = True
-
-    def patch(self, model, operation):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentApplyOperationCFG",
+            category="latent/advanced/operations",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.LatentOperation.Input("operation"),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, operation) -> io.NodeOutput:
         m = model.clone()

         def pre_cfg_function(args):
@@ -261,21 +315,25 @@ class LatentApplyOperationCFG:
             return conds_out

         m.set_model_sampler_pre_cfg_function(pre_cfg_function)
-        return (m, )
+        return io.NodeOutput(m)

-class LatentOperationTonemapReinhard:
+class LatentOperationTonemapReinhard(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
-                              }}
-
-    RETURN_TYPES = ("LATENT_OPERATION",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced/operations"
-    EXPERIMENTAL = True
-
-    def op(self, multiplier):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentOperationTonemapReinhard",
+            category="latent/advanced/operations",
+            is_experimental=True,
+            inputs=[
+                io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
+            ],
+            outputs=[
+                io.LatentOperation.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, multiplier) -> io.NodeOutput:
         def tonemap_reinhard(latent, **kwargs):
             latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
             normalized_latent = latent / latent_vector_magnitude
@@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard:
             new_magnitude *= top

            return normalized_latent * new_magnitude
-        return (tonemap_reinhard,)
+        return io.NodeOutput(tonemap_reinhard)

-class LatentOperationSharpen:
+class LatentOperationSharpen(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-                "sharpen_radius": ("INT", {
-                    "default": 9,
-                    "min": 1,
-                    "max": 31,
-                    "step": 1
-                }),
-                "sigma": ("FLOAT", {
-                    "default": 1.0,
-                    "min": 0.1,
-                    "max": 10.0,
-                    "step": 0.1
-                }),
-                "alpha": ("FLOAT", {
-                    "default": 0.1,
-                    "min": 0.0,
-                    "max": 5.0,
-                    "step": 0.01
-                }),
-        }}
-
-    RETURN_TYPES = ("LATENT_OPERATION",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/advanced/operations"
-    EXPERIMENTAL = True
-
-    def op(self, sharpen_radius, sigma, alpha):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentOperationSharpen",
+            category="latent/advanced/operations",
+            is_experimental=True,
+            inputs=[
+                io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
+                io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
+                io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
+            ],
+            outputs=[
+                io.LatentOperation.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
         def sharpen(latent, **kwargs):
             luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
             normalized_latent = latent / luminance
@@ -340,19 +386,27 @@ class LatentOperationSharpen:
             sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]

            return luminance * sharpened
-        return (sharpen,)
+        return io.NodeOutput(sharpen)

-NODE_CLASS_MAPPINGS = {
-    "LatentAdd": LatentAdd,
-    "LatentSubtract": LatentSubtract,
-    "LatentMultiply": LatentMultiply,
-    "LatentInterpolate": LatentInterpolate,
-    "LatentConcat": LatentConcat,
-    "LatentCut": LatentCut,
-    "LatentBatch": LatentBatch,
-    "LatentBatchSeedBehavior": LatentBatchSeedBehavior,
-    "LatentApplyOperation": LatentApplyOperation,
-    "LatentApplyOperationCFG": LatentApplyOperationCFG,
-    "LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
-    "LatentOperationSharpen": LatentOperationSharpen,
-}
+
+class LatentExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            LatentAdd,
+            LatentSubtract,
+            LatentMultiply,
+            LatentInterpolate,
+            LatentConcat,
+            LatentCut,
+            LatentBatch,
+            LatentBatchSeedBehavior,
+            LatentApplyOperation,
+            LatentApplyOperationCFG,
+            LatentOperationTonemapReinhard,
+            LatentOperationSharpen,
+        ]


+async def comfy_entrypoint() -> LatentExtension:
+    return LatentExtension()
@@ -5,6 +5,8 @@ import folder_paths
 import os
 import logging
 from enum import Enum
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io

 CLAMP_QUANTILE = 0.99

@@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
             output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
     return output_sd

-class LoraSave:
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
-
+class LoraSave(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LoraSave",
+            display_name="Extract and Save Lora",
+            category="_for_testing",
+            inputs=[
+                io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
+                io.Int.Input("rank", default=8, min=1, max=4096, step=1),
+                io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
+                io.Boolean.Input("bias_diff", default=True),
+                io.Model.Input(
+                    "model_diff",
+                    tooltip="The ModelSubtract output to be converted to a lora.",
+                    optional=True,
+                ),
+                io.Clip.Input(
+                    "text_encoder_diff",
+                    tooltip="The CLIPSubtract output to be converted to a lora.",
+                    optional=True,
+                ),
+            ],
+            is_experimental=True,
+            is_output_node=True,
+        )
+
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
-                             "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
-                             "lora_type": (tuple(LORA_TYPES.keys()),),
-                             "bias_diff": ("BOOLEAN", {"default": True}),
-                             },
-                "optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
-                             "text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
-                }
-    RETURN_TYPES = ()
-    FUNCTION = "save"
-    OUTPUT_NODE = True
-
-    CATEGORY = "_for_testing"
-
-    def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
+    def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
         if model_diff is None and text_encoder_diff is None:
-            return {}
+            return io.NodeOutput()

         lora_type = LORA_TYPES.get(lora_type)
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())

         output_sd = {}
         if model_diff is not None:
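A behavioural detail in this LoraSave hunk: the old class cached the output directory once per instance in __init__ (self.output_dir), but the schema-based execute() is a classmethod with no instance state, so the directory is now resolved at call time via folder_paths.get_output_directory().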
@@ -108,12 +118,16 @@ class LoraSave:
         output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

         comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
-        return {}
+        return io.NodeOutput()

-NODE_CLASS_MAPPINGS = {
-    "LoraSave": LoraSave
-}

-NODE_DISPLAY_NAME_MAPPINGS = {
-    "LoraSave": "Extract and Save Lora"
-}
+class LoraSaveExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            LoraSave,
+        ]
+
+
+async def comfy_entrypoint() -> LoraSaveExtension:
+    return LoraSaveExtension()
@@ -1,24 +1,33 @@
+from typing_extensions import override
 import comfy.utils
+from comfy_api.latest import ComfyExtension, io

-class PatchModelAddDownscale:
-    upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
+
+class PatchModelAddDownscale(io.ComfyNode):
+    UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
+
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
-                              "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
-                              "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
-                              "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
-                              "downscale_after_skip": ("BOOLEAN", {"default": True}),
-                              "downscale_method": (s.upscale_methods,),
-                              "upscale_method": (s.upscale_methods,),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="PatchModelAddDownscale",
+            display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
+            category="model_patches/unet",
+            inputs=[
+                io.Model.Input("model"),
+                io.Int.Input("block_number", default=3, min=1, max=32, step=1),
+                io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
+                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
+                io.Boolean.Input("downscale_after_skip", default=True),
+                io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
+                io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )

-    CATEGORY = "model_patches/unet"
-
-    def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
+    @classmethod
+    def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
         model_sampling = model.get_model_object("model_sampling")
         sigma_start = model_sampling.percent_to_sigma(start_percent)
         sigma_end = model_sampling.percent_to_sigma(end_percent)
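Besides the schema conversion, the options list here is renamed from upscale_methods to UPSCALE_METHODS and is read as cls.UPSCALE_METHODS inside define_schema(), replacing the old access through the INPUT_TYPES argument s.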
@@ -41,13 +50,21 @@ class PatchModelAddDownscale:
         else:
             m.set_model_input_block_patch(input_block_patch)
         m.set_model_output_block_patch(output_block_patch)
-        return (m, )
+        return io.NodeOutput(m)

-NODE_CLASS_MAPPINGS = {
-    "PatchModelAddDownscale": PatchModelAddDownscale,
-}

 NODE_DISPLAY_NAME_MAPPINGS = {
     # Sampling
-    "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
+    "PatchModelAddDownscale": "",
 }
+
+
+class ModelDownscaleExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            PatchModelAddDownscale,
+        ]
+
+
+async def comfy_entrypoint() -> ModelDownscaleExtension:
+    return ModelDownscaleExtension()
@@ -61,7 +61,6 @@ messages_control.disable = [
     # next warnings should be fixed in future
     "bad-classmethod-argument",  # Class method should have 'cls' as first argument
     "wrong-import-order",  # Standard imports should be placed before third party imports
-    "logging-fstring-interpolation",  # Use lazy % formatting in logging functions
     "ungrouped-imports",
     "unnecessary-pass",
     "unnecessary-lambda-assignment",
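This final hunk is the payoff of the logging changes above: with every f-string log call converted to lazy %-formatting, the logging-fstring-interpolation suppression can be dropped from pylint's disable list, so the check is enforced for future code.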