mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 11:32:58 +08:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
commit
fc5703c468
@ -431,7 +431,7 @@ async def upload_video_to_comfyapi(
|
||||
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting video duration: {e}")
|
||||
logging.error("Error getting video duration: %s", str(e))
|
||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||
|
||||
upload_mime_type = f"video/{container.value.lower()}"
|
||||
|
||||
@ -359,10 +359,10 @@ class ApiClient:
|
||||
if params:
|
||||
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
|
||||
|
||||
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
|
||||
logging.debug(f"[DEBUG] Files: {files}")
|
||||
logging.debug(f"[DEBUG] Params: {params}")
|
||||
logging.debug(f"[DEBUG] Data: {data}")
|
||||
logging.debug("[DEBUG] Request Headers: %s", request_headers)
|
||||
logging.debug("[DEBUG] Files: %s", files)
|
||||
logging.debug("[DEBUG] Params: %s", params)
|
||||
logging.debug("[DEBUG] Data: %s", data)
|
||||
|
||||
if content_type == "application/x-www-form-urlencoded":
|
||||
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
|
||||
@ -592,9 +592,9 @@ class ApiClient:
|
||||
error_message=f"HTTP Error {exc.status}",
|
||||
)
|
||||
|
||||
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
|
||||
logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
|
||||
if response_content:
|
||||
logging.debug(f"[DEBUG] Response content: {response_content}")
|
||||
logging.debug("[DEBUG] Response content: %s", response_content)
|
||||
|
||||
# Retry if eligible
|
||||
if status_code in self.retry_status_codes and retry_count < self.max_retries:
|
||||
@ -738,11 +738,9 @@ class SynchronousOperation(Generic[T, R]):
|
||||
if isinstance(v, Enum):
|
||||
request_dict[k] = v.value
|
||||
|
||||
logging.debug(
|
||||
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
|
||||
)
|
||||
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
||||
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
||||
logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
|
||||
logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
|
||||
logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
|
||||
|
||||
response_json = await client.request(
|
||||
self.endpoint.method.value,
|
||||
@ -757,11 +755,11 @@ class SynchronousOperation(Generic[T, R]):
|
||||
logging.debug("=" * 50)
|
||||
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
||||
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
||||
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
|
||||
logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
|
||||
logging.debug("=" * 50)
|
||||
|
||||
parsed_response = self.endpoint.response_model.model_validate(response_json)
|
||||
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
|
||||
logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
|
||||
return parsed_response
|
||||
finally:
|
||||
if owns_client:
|
||||
@ -877,7 +875,7 @@ class PollingOperation(Generic[T, R]):
|
||||
status = TaskStatus.PENDING
|
||||
for poll_count in range(1, self.max_poll_attempts + 1):
|
||||
try:
|
||||
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
|
||||
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
|
||||
|
||||
request_dict = (
|
||||
None if self.request is None else self.request.model_dump(exclude_none=True)
|
||||
@ -885,10 +883,13 @@ class PollingOperation(Generic[T, R]):
|
||||
|
||||
if poll_count == 1:
|
||||
logging.debug(
|
||||
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
|
||||
"[DEBUG] Poll Request: %s %s",
|
||||
self.poll_endpoint.method.value,
|
||||
self.poll_endpoint.path,
|
||||
)
|
||||
logging.debug(
|
||||
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
|
||||
"[DEBUG] Poll Request Data: %s",
|
||||
json.dumps(request_dict, indent=2) if request_dict else "None",
|
||||
)
|
||||
|
||||
# Query task status
|
||||
@ -903,7 +904,7 @@ class PollingOperation(Generic[T, R]):
|
||||
|
||||
# Check if task is complete
|
||||
status = self._check_task_status(response_obj)
|
||||
logging.debug(f"[DEBUG] Task Status: {status}")
|
||||
logging.debug("[DEBUG] Task Status: %s", status)
|
||||
|
||||
# If progress extractor is provided, extract progress
|
||||
if self.progress_extractor:
|
||||
@ -917,7 +918,7 @@ class PollingOperation(Generic[T, R]):
|
||||
result_url = self.result_url_extractor(response_obj)
|
||||
if result_url:
|
||||
message = f"Result URL: {result_url}"
|
||||
logging.debug(f"[DEBUG] {message}")
|
||||
logging.debug("[DEBUG] %s", message)
|
||||
self._display_text_on_node(message)
|
||||
self.final_response = response_obj
|
||||
if self.progress_extractor:
|
||||
@ -925,7 +926,7 @@ class PollingOperation(Generic[T, R]):
|
||||
return self.final_response
|
||||
if status == TaskStatus.FAILED:
|
||||
message = f"Task failed: {json.dumps(resp)}"
|
||||
logging.error(f"[DEBUG] {message}")
|
||||
logging.error("[DEBUG] %s", message)
|
||||
raise Exception(message)
|
||||
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
||||
# Task pending – wait
|
||||
@ -939,7 +940,12 @@ class PollingOperation(Generic[T, R]):
|
||||
raise Exception(
|
||||
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
|
||||
) from e
|
||||
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
|
||||
logging.warning(
|
||||
"Network error (%s/%s): %s",
|
||||
consecutive_errors,
|
||||
max_consecutive_errors,
|
||||
str(e),
|
||||
)
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
except Exception as e:
|
||||
# For other errors, increment count and potentially abort
|
||||
@ -949,10 +955,13 @@ class PollingOperation(Generic[T, R]):
|
||||
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
||||
) from e
|
||||
|
||||
logging.error(f"[DEBUG] Polling error: {str(e)}")
|
||||
logging.error("[DEBUG] Polling error: %s", str(e))
|
||||
logging.warning(
|
||||
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
|
||||
f"Will retry in {self.poll_interval} seconds."
|
||||
"Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
|
||||
poll_count,
|
||||
self.max_poll_attempts,
|
||||
str(e),
|
||||
self.poll_interval,
|
||||
)
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ def get_log_directory():
|
||||
try:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating API log directory {log_dir}: {e}")
|
||||
logger.error("Error creating API log directory %s: %s", log_dir, str(e))
|
||||
# Fallback to base temp directory if sub-directory creation fails
|
||||
return base_temp_dir
|
||||
return log_dir
|
||||
@ -122,9 +122,9 @@ def log_request_response(
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug(f"API log saved to: {filepath}")
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing API log to {filepath}: {e}")
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -296,7 +296,7 @@ def validate_video_result_response(response) -> None:
|
||||
"""Validates that the Kling task result contains a video."""
|
||||
if not is_valid_video_response(response):
|
||||
error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response."
|
||||
logging.error(f"Error: {error_msg}.\nResponse: {response}")
|
||||
logging.error("Error: %s.\nResponse: %s", error_msg, response)
|
||||
raise Exception(error_msg)
|
||||
|
||||
|
||||
@ -304,7 +304,7 @@ def validate_image_result_response(response) -> None:
|
||||
"""Validates that the Kling task result contains an image."""
|
||||
if not is_valid_image_response(response):
|
||||
error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response."
|
||||
logging.error(f"Error: {error_msg}.\nResponse: {response}")
|
||||
logging.error("Error: %s.\nResponse: %s", error_msg, response)
|
||||
raise Exception(error_msg)
|
||||
|
||||
|
||||
|
||||
@ -500,7 +500,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info(f"Generated video URL: {file_url}")
|
||||
logging.info("Generated video URL: %s", file_url)
|
||||
if cls.hidden.unique_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
|
||||
@ -237,7 +237,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
audio_stream = None
|
||||
|
||||
for stream in input_container.streams:
|
||||
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
|
||||
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
|
||||
if isinstance(stream, av.VideoStream):
|
||||
# Create output video stream with same parameters
|
||||
video_stream = output_container.add_stream(
|
||||
@ -247,7 +247,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
video_stream.height = stream.height
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
logging.info(
|
||||
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
|
||||
"Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
|
||||
)
|
||||
elif isinstance(stream, av.AudioStream):
|
||||
# Create output audio stream with same parameters
|
||||
@ -256,9 +256,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
logging.info(
|
||||
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
|
||||
)
|
||||
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
|
||||
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
@ -288,9 +286,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info(
|
||||
f"Encoded {frame_count} video frames (target: {target_frames})"
|
||||
)
|
||||
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
|
||||
|
||||
# Decode and re-encode audio frames
|
||||
if audio_stream:
|
||||
@ -308,7 +304,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info(f"Encoded {audio_frame_count} audio frames")
|
||||
logging.info("Encoded %s audio frames", audio_frame_count)
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
|
||||
@ -172,16 +172,16 @@ async def create_generate_task(
|
||||
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
|
||||
subscription_key = response.jobs.subscription_key
|
||||
task_uuid = response.uuid
|
||||
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
|
||||
logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid)
|
||||
return task_uuid, subscription_key
|
||||
|
||||
|
||||
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
||||
all_done = all(job.status == JobStatus.Done for job in response.jobs)
|
||||
status_list = [str(job.status) for job in response.jobs]
|
||||
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
|
||||
logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list)
|
||||
if any(job.status == JobStatus.Failed for job in response.jobs):
|
||||
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
|
||||
logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list)
|
||||
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
|
||||
if all_done:
|
||||
return "DONE"
|
||||
@ -235,7 +235,7 @@ async def download_files(url_list, task_uuid):
|
||||
file_path = os.path.join(save_path, file_name)
|
||||
if file_path.endswith(".glb"):
|
||||
model_file_path = file_path
|
||||
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
|
||||
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
@ -246,7 +246,7 @@ async def download_files(url_list, task_uuid):
|
||||
f.write(chunk)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
|
||||
logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
|
||||
if attempt < max_retries - 1:
|
||||
logging.info("Retrying...")
|
||||
await asyncio.sleep(2)
|
||||
|
||||
@ -215,7 +215,7 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||
initial_response = await initial_operation.execute()
|
||||
operation_name = initial_response.name
|
||||
|
||||
logging.info(f"Veo generation started with operation name: {operation_name}")
|
||||
logging.info("Veo generation started with operation name: %s", operation_name)
|
||||
|
||||
# Define status extractor function
|
||||
def status_extractor(response):
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
from enum import Enum
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def resize_mask(mask, shape):
|
||||
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
|
||||
@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
|
||||
return out_image, out_alpha
|
||||
|
||||
|
||||
class PorterDuffImageComposite:
|
||||
class PorterDuffImageComposite(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"source": ("IMAGE",),
|
||||
"source_alpha": ("MASK",),
|
||||
"destination": ("IMAGE",),
|
||||
"destination_alpha": ("MASK",),
|
||||
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
|
||||
},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PorterDuffImageComposite",
|
||||
display_name="Porter-Duff Image Composite",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
io.Image.Input("source"),
|
||||
io.Mask.Input("source_alpha"),
|
||||
io.Image.Input("destination"),
|
||||
io.Mask.Input("destination_alpha"),
|
||||
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
io.Mask.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "composite"
|
||||
CATEGORY = "mask/compositing"
|
||||
|
||||
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
|
||||
@classmethod
|
||||
def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
|
||||
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
||||
out_images = []
|
||||
out_alphas = []
|
||||
@ -150,45 +157,48 @@ class PorterDuffImageComposite:
|
||||
out_images.append(out_image)
|
||||
out_alphas.append(out_alpha.squeeze(2))
|
||||
|
||||
result = (torch.stack(out_images), torch.stack(out_alphas))
|
||||
return result
|
||||
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
|
||||
|
||||
|
||||
class SplitImageWithAlpha:
|
||||
class SplitImageWithAlpha(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SplitImageWithAlpha",
|
||||
display_name="Split Image with Alpha",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
io.Mask.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "mask/compositing"
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "split_image_with_alpha"
|
||||
|
||||
def split_image_with_alpha(self, image: torch.Tensor):
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
|
||||
out_images = [i[:,:,:3] for i in image]
|
||||
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
|
||||
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
||||
return result
|
||||
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
||||
|
||||
|
||||
class JoinImageWithAlpha:
|
||||
class JoinImageWithAlpha(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"alpha": ("MASK",),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="JoinImageWithAlpha",
|
||||
display_name="Join Image with Alpha",
|
||||
category="mask/compositing",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Mask.Input("alpha"),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask/compositing"
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "join_image_with_alpha"
|
||||
|
||||
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||
batch_size = min(len(image), len(alpha))
|
||||
out_images = []
|
||||
|
||||
@ -196,19 +206,18 @@ class JoinImageWithAlpha:
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
||||
result = (torch.stack(out_images),)
|
||||
return result
|
||||
return io.NodeOutput(torch.stack(out_images))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PorterDuffImageComposite": PorterDuffImageComposite,
|
||||
"SplitImageWithAlpha": SplitImageWithAlpha,
|
||||
"JoinImageWithAlpha": JoinImageWithAlpha,
|
||||
}
|
||||
class CompositingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
PorterDuffImageComposite,
|
||||
SplitImageWithAlpha,
|
||||
JoinImageWithAlpha,
|
||||
]
|
||||
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PorterDuffImageComposite": "Porter-Duff Image Composite",
|
||||
"SplitImageWithAlpha": "Split Image with Alpha",
|
||||
"JoinImageWithAlpha": "Join Image with Alpha",
|
||||
}
|
||||
async def comfy_entrypoint() -> CompositingExtension:
|
||||
return CompositingExtension()
|
||||
|
||||
@ -2,6 +2,8 @@ import comfy.utils
|
||||
import comfy_extras.nodes_post_processing
|
||||
import torch
|
||||
import nodes
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||
@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||
return latent
|
||||
|
||||
|
||||
class LatentAdd:
|
||||
class LatentAdd(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentAdd",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples1, samples2):
|
||||
@classmethod
|
||||
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
@ -31,19 +39,25 @@ class LatentAdd:
|
||||
|
||||
s2 = reshape_latent_to(s1.shape, s2)
|
||||
samples_out["samples"] = s1 + s2
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentSubtract:
|
||||
class LatentSubtract(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentSubtract",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples1, samples2):
|
||||
@classmethod
|
||||
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
@ -51,41 +65,49 @@ class LatentSubtract:
|
||||
|
||||
s2 = reshape_latent_to(s1.shape, s2)
|
||||
samples_out["samples"] = s1 - s2
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentMultiply:
|
||||
class LatentMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentMultiply",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples, multiplier):
|
||||
@classmethod
|
||||
def execute(cls, samples, multiplier) -> io.NodeOutput:
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
samples_out["samples"] = s1 * multiplier
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentInterpolate:
|
||||
class LatentInterpolate(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",),
|
||||
"samples2": ("LATENT",),
|
||||
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentInterpolate",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples1, samples2, ratio):
|
||||
@classmethod
|
||||
def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
@ -104,19 +126,26 @@ class LatentInterpolate:
|
||||
st = torch.nan_to_num(t / mt)
|
||||
|
||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentConcat:
|
||||
class LatentConcat(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentConcat",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples1, samples2, dim):
|
||||
@classmethod
|
||||
def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
@ -136,22 +165,27 @@ class LatentConcat:
|
||||
dim = -3
|
||||
|
||||
samples_out["samples"] = torch.cat(c, dim=dim)
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentCut:
|
||||
class LatentCut(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT",),
|
||||
"dim": (["x", "y", "t"], ),
|
||||
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentCut",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Combo.Input("dim", options=["x", "y", "t"]),
|
||||
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
|
||||
io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples, dim, index, amount):
|
||||
@classmethod
|
||||
def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
@ -171,19 +205,25 @@ class LatentCut:
|
||||
amount = min(-index, amount)
|
||||
|
||||
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentBatch:
|
||||
class LatentBatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentBatch",
|
||||
category="latent/batch",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "batch"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def batch(self, samples1, samples2):
|
||||
@classmethod
|
||||
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
||||
samples_out = samples1.copy()
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
@ -192,20 +232,25 @@ class LatentBatch:
|
||||
s = torch.cat((s1, s2), dim=0)
|
||||
samples_out["samples"] = s
|
||||
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentBatchSeedBehavior:
|
||||
class LatentBatchSeedBehavior(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentBatchSeedBehavior",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples, seed_behavior):
|
||||
@classmethod
|
||||
def execute(cls, samples, seed_behavior) -> io.NodeOutput:
|
||||
samples_out = samples.copy()
|
||||
latent = samples["samples"]
|
||||
if seed_behavior == "random":
|
||||
@ -215,41 +260,50 @@ class LatentBatchSeedBehavior:
|
||||
batch_number = samples_out.get("batch_index", [0])[0]
|
||||
samples_out["batch_index"] = [batch_number] * latent.shape[0]
|
||||
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentApplyOperation:
|
||||
class LatentApplyOperation(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"operation": ("LATENT_OPERATION",),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentApplyOperation",
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.LatentOperation.Input("operation"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced/operations"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def op(self, samples, operation):
|
||||
@classmethod
|
||||
def execute(cls, samples, operation) -> io.NodeOutput:
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
samples_out["samples"] = operation(latent=s1)
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentApplyOperationCFG:
|
||||
class LatentApplyOperationCFG(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"operation": ("LATENT_OPERATION",),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentApplyOperationCFG",
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.LatentOperation.Input("operation"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "latent/advanced/operations"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch(self, model, operation):
|
||||
@classmethod
|
||||
def execute(cls, model, operation) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
|
||||
def pre_cfg_function(args):
|
||||
@ -261,21 +315,25 @@ class LatentApplyOperationCFG:
|
||||
return conds_out
|
||||
|
||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class LatentOperationTonemapReinhard:
|
||||
class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentOperationTonemapReinhard",
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.LatentOperation.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT_OPERATION",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced/operations"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def op(self, multiplier):
|
||||
@classmethod
|
||||
def execute(cls, multiplier) -> io.NodeOutput:
|
||||
def tonemap_reinhard(latent, **kwargs):
|
||||
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||
normalized_latent = latent / latent_vector_magnitude
|
||||
@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard:
|
||||
new_magnitude *= top
|
||||
|
||||
return normalized_latent * new_magnitude
|
||||
return (tonemap_reinhard,)
|
||||
return io.NodeOutput(tonemap_reinhard)
|
||||
|
||||
class LatentOperationSharpen:
|
||||
class LatentOperationSharpen(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"sharpen_radius": ("INT", {
|
||||
"default": 9,
|
||||
"min": 1,
|
||||
"max": 31,
|
||||
"step": 1
|
||||
}),
|
||||
"sigma": ("FLOAT", {
|
||||
"default": 1.0,
|
||||
"min": 0.1,
|
||||
"max": 10.0,
|
||||
"step": 0.1
|
||||
}),
|
||||
"alpha": ("FLOAT", {
|
||||
"default": 0.1,
|
||||
"min": 0.0,
|
||||
"max": 5.0,
|
||||
"step": 0.01
|
||||
}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentOperationSharpen",
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
|
||||
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
|
||||
io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.LatentOperation.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT_OPERATION",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced/operations"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def op(self, sharpen_radius, sigma, alpha):
|
||||
@classmethod
|
||||
def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
|
||||
def sharpen(latent, **kwargs):
|
||||
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
|
||||
normalized_latent = latent / luminance
|
||||
@ -340,19 +386,27 @@ class LatentOperationSharpen:
|
||||
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||
|
||||
return luminance * sharpened
|
||||
return (sharpen,)
|
||||
return io.NodeOutput(sharpen)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentAdd": LatentAdd,
|
||||
"LatentSubtract": LatentSubtract,
|
||||
"LatentMultiply": LatentMultiply,
|
||||
"LatentInterpolate": LatentInterpolate,
|
||||
"LatentConcat": LatentConcat,
|
||||
"LatentCut": LatentCut,
|
||||
"LatentBatch": LatentBatch,
|
||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
||||
"LatentApplyOperation": LatentApplyOperation,
|
||||
"LatentApplyOperationCFG": LatentApplyOperationCFG,
|
||||
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
|
||||
"LatentOperationSharpen": LatentOperationSharpen,
|
||||
}
|
||||
|
||||
class LatentExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LatentAdd,
|
||||
LatentSubtract,
|
||||
LatentMultiply,
|
||||
LatentInterpolate,
|
||||
LatentConcat,
|
||||
LatentCut,
|
||||
LatentBatch,
|
||||
LatentBatchSeedBehavior,
|
||||
LatentApplyOperation,
|
||||
LatentApplyOperationCFG,
|
||||
LatentOperationTonemapReinhard,
|
||||
LatentOperationSharpen,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LatentExtension:
|
||||
return LatentExtension()
|
||||
|
||||
@ -5,6 +5,8 @@ import folder_paths
|
||||
import os
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
|
||||
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||
return output_sd
|
||||
|
||||
class LoraSave:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
class LoraSave(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoraSave",
|
||||
display_name="Extract and Save Lora",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
||||
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
|
||||
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
|
||||
io.Boolean.Input("bias_diff", default=True),
|
||||
io.Model.Input(
|
||||
"model_diff",
|
||||
tooltip="The ModelSubtract output to be converted to a lora.",
|
||||
optional=True,
|
||||
),
|
||||
io.Clip.Input(
|
||||
"text_encoder_diff",
|
||||
tooltip="The CLIPSubtract output to be converted to a lora.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
||||
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
|
||||
"lora_type": (tuple(LORA_TYPES.keys()),),
|
||||
"bias_diff": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
|
||||
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
|
||||
}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
|
||||
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
|
||||
if model_diff is None and text_encoder_diff is None:
|
||||
return {}
|
||||
return io.NodeOutput()
|
||||
|
||||
lora_type = LORA_TYPES.get(lora_type)
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
|
||||
output_sd = {}
|
||||
if model_diff is not None:
|
||||
@ -108,12 +118,16 @@ class LoraSave:
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||
return {}
|
||||
return io.NodeOutput()
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LoraSave": LoraSave
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LoraSave": "Extract and Save Lora"
|
||||
}
|
||||
class LoraSaveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LoraSave,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LoraSaveExtension:
|
||||
return LoraSaveExtension()
|
||||
|
||||
@ -1,24 +1,33 @@
|
||||
from typing_extensions import override
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class PatchModelAddDownscale:
|
||||
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||
|
||||
class PatchModelAddDownscale(io.ComfyNode):
|
||||
UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
|
||||
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"downscale_after_skip": ("BOOLEAN", {"default": True}),
|
||||
"downscale_method": (s.upscale_methods,),
|
||||
"upscale_method": (s.upscale_methods,),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PatchModelAddDownscale",
|
||||
display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
|
||||
category="model_patches/unet",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Int.Input("block_number", default=3, min=1, max=32, step=1),
|
||||
io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
|
||||
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||
io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
|
||||
io.Boolean.Input("downscale_after_skip", default=True),
|
||||
io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
|
||||
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
||||
@classmethod
|
||||
def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
@ -41,13 +50,21 @@ class PatchModelAddDownscale:
|
||||
else:
|
||||
m.set_model_input_block_patch(input_block_patch)
|
||||
m.set_model_output_block_patch(output_block_patch)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PatchModelAddDownscale": PatchModelAddDownscale,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Sampling
|
||||
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
|
||||
"PatchModelAddDownscale": "",
|
||||
}
|
||||
|
||||
class ModelDownscaleExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
PatchModelAddDownscale,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ModelDownscaleExtension:
|
||||
return ModelDownscaleExtension()
|
||||
|
||||
@ -61,7 +61,6 @@ messages_control.disable = [
|
||||
# next warnings should be fixed in future
|
||||
"bad-classmethod-argument", # Class method should have 'cls' as first argument
|
||||
"wrong-import-order", # Standard imports should be placed before third party imports
|
||||
"logging-fstring-interpolation", # Use lazy % formatting in logging functions
|
||||
"ungrouped-imports",
|
||||
"unnecessary-pass",
|
||||
"unnecessary-lambda-assignment",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user