mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
commit
671a769dc6
@ -35,7 +35,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
import inspect
|
||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||
@ -71,7 +71,6 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
@torch.compiler.disable()
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
|
||||
@ -126,12 +126,12 @@ class QuantizedTensor(torch.Tensor):
|
||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata.contiguous()
|
||||
self._qdata = qdata
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
|
||||
def __repr__(self):
|
||||
layout_name = self._layout_type.__name__
|
||||
layout_name = self._layout_type
|
||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||
|
||||
@ -179,7 +179,7 @@ class QuantizedTensor(torch.Tensor):
|
||||
attr_name = f"_layout_param_{key}"
|
||||
layout_params[key] = inner_tensors[attr_name]
|
||||
|
||||
return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)
|
||||
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||
@ -411,13 +411,17 @@ def fp8_linear(func, args, kwargs):
|
||||
|
||||
try:
|
||||
output = torch._scaled_mm(
|
||||
plain_input.reshape(-1, input_shape[2]),
|
||||
plain_input.reshape(-1, input_shape[2]).contiguous(),
|
||||
weight_t,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
|
||||
if not tensor_2d:
|
||||
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||
|
||||
@ -442,6 +446,59 @@ def fp8_linear(func, args, kwargs):
|
||||
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
|
||||
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
output = torch._scaled_mm(
|
||||
plain_input.contiguous(),
|
||||
plain_weight,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
|
||||
def fp8_addmm(func, args, kwargs):
|
||||
input_tensor = args[1]
|
||||
weight = args[2]
|
||||
bias = args[0]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
if isinstance(args[2], QuantizedTensor):
|
||||
a[2] = args[2].dequantize()
|
||||
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
|
||||
def fp8_mm(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
||||
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
||||
|
||||
@ -1,73 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import mimetypes
|
||||
from typing import Union
|
||||
from server import PromptServer
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
async def validate_and_cast_response(
|
||||
response, timeout: int = None, node_id: Union[str, None] = None
|
||||
) -> torch.Tensor:
|
||||
"""Validates and casts a response to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
response: The response to validate and cast.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
ValueError: If the response is not valid.
|
||||
"""
|
||||
# validate raw JSON response
|
||||
data = response.data
|
||||
if not data or len(data) == 0:
|
||||
raise ValueError("No images returned from API endpoint")
|
||||
|
||||
# Initialize list to store image tensors
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
|
||||
# Process each image in the data array
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
||||
for img_data in data:
|
||||
img_bytes: bytes
|
||||
if img_data.b64_json:
|
||||
img_bytes = base64.b64decode(img_data.b64_json)
|
||||
elif img_data.url:
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
|
||||
async with session.get(img_data.url) as resp:
|
||||
if resp.status != 200:
|
||||
raise ValueError("Failed to download generated image")
|
||||
img_bytes = await resp.read()
|
||||
else:
|
||||
raise ValueError("Invalid image payload – neither URL nor base64 data present.")
|
||||
|
||||
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
|
||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
||||
image_tensors.append(torch.from_numpy(arr))
|
||||
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||
"""Converts a text file to a base64 string."""
|
||||
with open(filepath, "rb") as f:
|
||||
file_content = f.read()
|
||||
return base64.b64encode(file_content).decode("utf-8")
|
||||
|
||||
|
||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||
"""Converts a text file to a data URI."""
|
||||
base64_string = text_filepath_to_base64_string(filepath)
|
||||
mime_type, _ = mimetypes.guess_type(filepath)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
File diff suppressed because it is too large
Load Diff
@ -7,24 +7,23 @@ from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from typing import Optional, TypeVar
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||
from comfy_api_nodes.apis import pika_defs
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.apis import pika_api as pika_defs
|
||||
from comfy_api_nodes.util import (
|
||||
validate_string,
|
||||
download_url_to_video_output,
|
||||
tensor_to_bytesio,
|
||||
ApiEndpoint,
|
||||
EmptyRequest,
|
||||
HttpMethod,
|
||||
PollingOperation,
|
||||
SynchronousOperation,
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||
@ -40,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||
|
||||
|
||||
async def execute_task(
|
||||
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
task_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
) -> IO.NodeOutput:
|
||||
task_id = (await initial_operation.execute()).video_id
|
||||
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||
response_model=pika_defs.PikaVideoResponse,
|
||||
),
|
||||
completed_statuses=["finished"],
|
||||
failed_statuses=["failed", "cancelled"],
|
||||
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
|
||||
node_id=node_id,
|
||||
estimated_duration=60,
|
||||
max_poll_attempts=240,
|
||||
).execute()
|
||||
)
|
||||
if not final_response.url:
|
||||
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||
logging.error(error_msg)
|
||||
@ -124,23 +113,15 @@ class PikaImageToVideo(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaTextToVideoNode(IO.ComfyNode):
|
||||
@ -183,18 +164,11 @@ class PikaTextToVideoNode(IO.ComfyNode):
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
) -> IO.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
@ -202,10 +176,9 @@ class PikaTextToVideoNode(IO.ComfyNode):
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaScenes(IO.ComfyNode):
|
||||
@ -309,24 +282,16 @@ class PikaScenes(IO.ComfyNode):
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASCENES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikAdditionsNode(IO.ComfyNode):
|
||||
@ -383,24 +348,16 @@ class PikAdditionsNode(IO.ComfyNode):
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKADDITIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaSwapsNode(IO.ComfyNode):
|
||||
@ -472,23 +429,15 @@ class PikaSwapsNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASWAPS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaffectsNode(IO.ComfyNode):
|
||||
@ -528,18 +477,11 @@ class PikaffectsNode(IO.ComfyNode):
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFFECTS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
pikaffect=pikaffect,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
@ -547,9 +489,8 @@ class PikaffectsNode(IO.ComfyNode):
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
@ -592,18 +533,11 @@ class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFRAMES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
@ -612,9 +546,8 @@ class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaApiNodesExtension(ComfyExtension):
|
||||
|
||||
@ -225,21 +225,20 @@ async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] =
|
||||
|
||||
|
||||
async def download_files(url_list, task_uuid):
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
|
||||
result_folder_name = f"Rodin3D_{task_uuid}"
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
model_file_path = None
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for i in url_list.list:
|
||||
url = i.url
|
||||
file_name = i.name
|
||||
file_path = os.path.join(save_path, file_name)
|
||||
file_path = os.path.join(save_path, i.name)
|
||||
if file_path.endswith(".glb"):
|
||||
model_file_path = file_path
|
||||
model_file_path = os.path.join(result_folder_name, i.name)
|
||||
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
async with session.get(i.url) as resp:
|
||||
resp.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
async for chunk in resp.content.iter_chunked(32 * 1024):
|
||||
|
||||
@ -18,6 +18,8 @@ from .conversions import (
|
||||
tensor_to_base64_string,
|
||||
tensor_to_bytesio,
|
||||
tensor_to_pil,
|
||||
text_filepath_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
trim_video,
|
||||
video_to_base64_string,
|
||||
)
|
||||
@ -75,6 +77,8 @@ __all__ = [
|
||||
"tensor_to_base64_string",
|
||||
"tensor_to_bytesio",
|
||||
"tensor_to_pil",
|
||||
"text_filepath_to_base64_string",
|
||||
"text_filepath_to_data_uri",
|
||||
"trim_video",
|
||||
"video_to_base64_string",
|
||||
# Validation utilities
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import mimetypes
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
@ -12,7 +13,7 @@ from PIL import Image
|
||||
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import Input, InputImpl
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
|
||||
from ._helpers import mimetype_to_extension
|
||||
|
||||
@ -451,3 +452,19 @@ def resize_mask_to_image(
|
||||
if not allow_gradient:
|
||||
mask = (mask > 0.5).float()
|
||||
return mask
|
||||
|
||||
|
||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||
"""Converts a text file to a base64 string."""
|
||||
with open(filepath, "rb") as f:
|
||||
file_content = f.read()
|
||||
return base64.b64encode(file_content).decode("utf-8")
|
||||
|
||||
|
||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||
"""Converts a text file to a data URI."""
|
||||
base64_string = text_filepath_to_base64_string(filepath)
|
||||
mime_type, _ = mimetypes.guess_type(filepath)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
@ -53,7 +53,7 @@ class Unhashable:
|
||||
def to_hashable(obj):
|
||||
# So that we don't infinitely recurse since frozenset and tuples
|
||||
# are Sequences.
|
||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
||||
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||
|
||||
@ -2,6 +2,9 @@ import comfy.utils
|
||||
import folder_paths
|
||||
import torch
|
||||
import logging
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def load_hypernetwork_patch(path, strength):
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
|
||||
|
||||
return hypernetwork_patch(out, strength)
|
||||
|
||||
class HypernetworkLoader:
|
||||
class HypernetworkLoader(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_hypernetwork"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="HypernetworkLoader",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
|
||||
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||
@classmethod
|
||||
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
|
||||
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||
model_hypernetwork = model.clone()
|
||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||
if patch is not None:
|
||||
model_hypernetwork.set_model_attn1_patch(patch)
|
||||
model_hypernetwork.set_model_attn2_patch(patch)
|
||||
return (model_hypernetwork,)
|
||||
return IO.NodeOutput(model_hypernetwork)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"HypernetworkLoader": HypernetworkLoader
|
||||
}
|
||||
load_hypernetwork = execute # TODO: remove
|
||||
|
||||
|
||||
class HyperNetworkExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
HypernetworkLoader,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> HyperNetworkExtension:
|
||||
return HyperNetworkExtension()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.28.8
|
||||
comfyui-workflow-templates==0.2.4
|
||||
comfyui-embedded-docs==0.3.0
|
||||
comfyui-embedded-docs==0.3.1
|
||||
comfyui_manager==4.0.3b1
|
||||
torch
|
||||
torchsde
|
||||
|
||||
Loading…
Reference in New Issue
Block a user