Merge branch 'master' of github.com:comfyanonymous/ComfyUI
Commit: f6339e8115
@@ -1,7 +1,7 @@
 import dataclasses
 from typing import List
 
-from typing_extensions import TypedDict, Literal, NotRequired, Dict
+from typing_extensions import TypedDict, Literal, NotRequired
 
 
 class FileOutput(TypedDict, total=False):
@@ -21,4 +21,4 @@ class Output(TypedDict, total=False):
 @dataclasses.dataclass
 class V1QueuePromptResponse:
     urls: List[str]
-    outputs: Dict[str, Output]
+    outputs: dict[str, Output]
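
Reviewer note (not part of the diff): dropping Dict from typing_extensions means the annotation now relies on the built-in generic dict[str, Output]. Subscripting dict this way needs Python 3.9+ when the annotation is evaluated eagerly; a minimal sketch of the usual escape hatch for older interpreters, using a hypothetical Example class:

    from __future__ import annotations  # PEP 563: keep annotations as strings at runtime

    import dataclasses


    @dataclasses.dataclass
    class Example:
        urls: list[str]
        outputs: dict[str, dict]  # would raise TypeError at class creation on Python < 3.9 without the future import
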
@@ -500,19 +500,21 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data,
 
         logging.error("An error occurred while executing a workflow", exc_info=ex)
         logging.error(traceback.format_exc())
+        tips = ""
 
+        if isinstance(ex, model_management.OOM_EXCEPTION):
+            tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
+            logging.error("Got an OOM, unloading all loaded models.")
+            model_management.unload_all_models()
 
         error_details: RecursiveExecutionErrorDetails = {
             "node_id": real_node_id,
-            "exception_message": str(ex),
+            "exception_message": "{}\n{}".format(ex, tips),
             "exception_type": exception_type,
             "traceback": traceback.format_tb(tb),
             "current_inputs": input_data_formatted
         }
 
-        if isinstance(ex, model_management.OOM_EXCEPTION):
-            logging.error("Got an OOM, unloading all loaded models.")
-            model_management.unload_all_models()
 
         if should_panic_on_exception(ex, args.panic_when):
             logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
 
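
Reviewer note (not part of the diff): the reported payload only changes in exception_message, which now carries the OOM tip after the original exception text. Roughly, with illustrative values:

    tips = "This error means you ran out of memory on your GPU."
    ex = RuntimeError("CUDA out of memory")
    print("{}\n{}".format(ex, tips))
    # CUDA out of memory
    # This error means you ran out of memory on your GPU.

For non-OOM exceptions tips stays empty, so the message only gains a trailing newline.
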
@@ -82,7 +82,12 @@ def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptS
 
         current_time = time.perf_counter()
         execution_time = current_time - execution_start_time
-        logger.debug("Prompt executed in {:.2f} seconds".format(execution_time))
+        # Log Time in a more readable way after 10 minutes
+        if execution_time > 600:
+            execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
+            logger.info(f"Prompt executed in {execution_time}")
+        else:
+            logger.info("Prompt executed in {:.2f} seconds".format(execution_time))
 
         flags = q.get_flags()
         free_memory = flags.get("free_memory", False)
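
Reviewer note (not part of the diff): a quick check of the new branching, factored into a hypothetical helper:

    import time

    def format_execution_time(execution_time: float) -> str:
        # Same branching as prompt_worker: switch to H:M:S past 10 minutes.
        if execution_time > 600:
            return time.strftime("%H:%M:%S", time.gmtime(execution_time))
        return "{:.2f} seconds".format(execution_time)

    print(format_execution_time(72.5))   # 72.50 seconds
    print(format_execution_time(3700))   # 01:01:40

One small caveat: strftime over gmtime wraps after 24 hours, which should be acceptable for prompt runtimes.
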
@@ -726,38 +726,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
     seed = extra_args.get("seed", None)
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
-    sigma_fn = lambda t: t.neg().exp()
-    t_fn = lambda sigma: sigma.log().neg()
+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         if sigmas[i + 1] == 0:
-            # Euler method
-            d = to_d(x, sigmas[i], denoised)
-            dt = sigmas[i + 1] - sigmas[i]
-            x = x + d * dt
+            # Denoising step
+            x = denoised
         else:
             # DPM-Solver++
-            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
-            h = t_next - t
-            s = t + h * r
+            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+            h = lambda_t - lambda_s
+            lambda_s_1 = lambda_s + r * h
             fac = 1 / (2 * r)
 
+            sigma_s_1 = sigma_fn(lambda_s_1)
+
+            alpha_s = sigmas[i] * lambda_s.exp()
+            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+            alpha_t = sigmas[i + 1] * lambda_t.exp()
+
             # Step 1
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
-            s_ = t_fn(sd)
-            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
-            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
-            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
+            lambda_s_1_ = sd.log().neg()
+            h_ = lambda_s_1_ - lambda_s
+            x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
+            if eta > 0 and s_noise > 0:
+                x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
+            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
 
             # Step 2
-            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
-            t_next_ = t_fn(sd)
+            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
+            lambda_t_ = sd.log().neg()
+            h_ = lambda_t_ - lambda_s
             denoised_d = (1 - fac) * denoised + fac * denoised_2
-            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
-            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
+            x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
+            if eta > 0 and s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
     return x
 
 
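
Reviewer note (not part of the diff): the sampler now works in half-log-SNR space, lambda = log(alpha / sigma), instead of the old negative-log-sigma variable. A minimal numeric sketch of that change of variables, assuming only that definition:

    import math

    def half_log_snr(alpha: float, sigma: float) -> float:
        return math.log(alpha / sigma)

    sigma, alpha = 0.5, 0.9
    lam = half_log_snr(alpha, sigma)

    # The new code recovers alpha the same way: alpha = sigma * exp(lambda).
    assert abs(sigma * math.exp(lam) - alpha) < 1e-12

    # With alpha == 1 the new variable collapses to the old t_fn = -log(sigma),
    # whose inverse is the old sigma_fn = exp(-t).
    assert abs(half_log_snr(1.0, sigma) - (-math.log(sigma))) < 1e-12

So the rewrite generalizes the previous code path to schedules where alpha is not 1, via the partial-bound helpers around model_sampling.
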
@@ -126,6 +126,8 @@ class ControlNetFlux(Flux):
 
         if y is None:
             y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+        else:
+            y = y[:, :self.params.vec_in_dim]
 
         # running on sequences img
         img = self.img_in(img)
@@ -117,7 +117,7 @@ class Modulation(nn.Module):
 def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
     if modulation_dims is None:
         if m_add is not None:
-            return tensor * m_mult + m_add
+            return torch.addcmul(m_add, tensor, m_mult)
         else:
             return tensor * m_mult
     else:
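
Reviewer note (not part of the diff): torch.addcmul(m_add, tensor, m_mult) computes m_add + tensor * m_mult as one fused op, so the result should match the old expression. Quick check with illustrative shapes:

    import torch

    tensor = torch.randn(2, 4, 8)
    m_mult = torch.randn(2, 1, 8)
    m_add = torch.randn(2, 1, 8)

    fused = torch.addcmul(m_add, tensor, m_mult)  # m_add + tensor * m_mult
    reference = tensor * m_mult + m_add

    assert torch.allclose(fused, reference)
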
@@ -31,7 +31,7 @@ def dynamic_slice(
     starts: List[int],
     sizes: List[int],
 ) -> Tensor:
-    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
+    slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
     return x[slicing]
 
 
 class AttnChunk(NamedTuple):
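
Reviewer note (not part of the diff): a tuple of slice objects is the canonical basic-indexing form and keeps this a cheap view; indexing with a list is the legacy spelling that NumPy deprecated for multidimensional indexing, so the tuple() wrapper removes that ambiguity without changing the result here. Sketch with illustrative shapes:

    import torch

    x = torch.arange(24).reshape(2, 3, 4)
    starts, sizes = (0, 1, 2), (2, 2, 2)

    slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
    print(x[slicing].shape)  # torch.Size([2, 2, 2])
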
@@ -1067,6 +1067,8 @@ class CosmosPredict2(BaseModel):
     def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
         if denoise_mask is None:
             return timestep
+        if denoise_mask.ndim <= 4:
+            return timestep
         condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
         c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
         out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
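
Reviewer note (not part of the diff): the added early return covers masks without a frame axis (4 or fewer dims), where the per-frame masking below would not apply. For the video case the shapes work out as follows, with illustrative sizes:

    import torch

    B, C, T, H, W = 2, 16, 8, 4, 4
    denoise_mask = torch.rand(B, C, T, H, W).round()
    timestep = torch.rand(B)

    mask = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)                       # (B, 1, T, 1, 1)
    c_noise = 0.0 * (1.0 - mask) + timestep.reshape(B, 1, 1, 1, 1) * mask        # broadcasts to (B, 1, T, 1, 1)
    print(c_noise.squeeze(dim=[1, 3, 4]).shape)                                  # torch.Size([2, 8])
    # squeezing several dims at once needs PyTorch 2.0+
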
@@ -11,6 +11,43 @@ from comfy_config.types import (
     PyProjectSettings
 )
 
+
+def validate_and_extract_os_classifiers(classifiers: list) -> list:
+    os_classifiers = [c for c in classifiers if c.startswith("Operating System :: ")]
+    if not os_classifiers:
+        return []
+
+    os_values = [c[len("Operating System :: ") :] for c in os_classifiers]
+    valid_os_prefixes = {"Microsoft", "POSIX", "MacOS", "OS Independent"}
+
+    for os_value in os_values:
+        if not any(os_value.startswith(prefix) for prefix in valid_os_prefixes):
+            return []
+
+    return os_values
+
+
+def validate_and_extract_accelerator_classifiers(classifiers: list) -> list:
+    accelerator_classifiers = [c for c in classifiers if c.startswith("Environment ::")]
+    if not accelerator_classifiers:
+        return []
+
+    accelerator_values = [c[len("Environment :: ") :] for c in accelerator_classifiers]
+
+    valid_accelerators = {
+        "GPU :: NVIDIA CUDA",
+        "GPU :: AMD ROCm",
+        "GPU :: Intel Arc",
+        "NPU :: Huawei Ascend",
+        "GPU :: Apple Metal",
+    }
+
+    for accelerator_value in accelerator_values:
+        if accelerator_value not in valid_accelerators:
+            return []
+
+    return accelerator_values
+
+
 """
 Extract configuration from a custom node directory's pyproject.toml file or a Python file.
 
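
Reviewer note (not part of the diff): with the two helpers above in scope, the expected behaviour on illustrative classifier lists is that valid entries come back with the prefix stripped, and a single unrecognized value rejects the whole category:

    classifiers = [
        "Operating System :: OS Independent",
        "Environment :: GPU :: NVIDIA CUDA",
        "Environment :: GPU :: AMD ROCm",
    ]

    print(validate_and_extract_os_classifiers(classifiers))
    # ['OS Independent']
    print(validate_and_extract_accelerator_classifiers(classifiers))
    # ['GPU :: NVIDIA CUDA', 'GPU :: AMD ROCm']

    print(validate_and_extract_os_classifiers(["Operating System :: Amiga"]))
    # []
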
@@ -78,6 +115,24 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:
     tool_data = raw_settings.tool
     comfy_data = tool_data.get("comfy", {}) if tool_data else {}
 
+    dependencies = project_data.get("dependencies", [])
+    supported_comfyui_frontend_version = ""
+    for dep in dependencies:
+        if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
+            supported_comfyui_frontend_version = dep.removeprefix("comfyui-frontend-package")
+            break
+
+    supported_comfyui_version = comfy_data.get("requires-comfyui", "")
+
+    classifiers = project_data.get('classifiers', [])
+    supported_os = validate_and_extract_os_classifiers(classifiers)
+    supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
+
+    project_data['supported_os'] = supported_os
+    project_data['supported_accelerators'] = supported_accelerators
+    project_data['supported_comfyui_frontend_version'] = supported_comfyui_frontend_version
+    project_data['supported_comfyui_version'] = supported_comfyui_version
+
     return PyProjectConfig(project=project_data, tool_comfy=comfy_data)
 
 
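
Reviewer note (not part of the diff): removeprefix keeps whatever version specifier follows the package name, so the stored value includes the comparator. With a hypothetical dependency string:

    dep = "comfyui-frontend-package==1.21.6"
    print(dep.removeprefix("comfyui-frontend-package"))  # ==1.21.6

str.removeprefix also means this code path needs Python 3.9 or newer.
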
@@ -51,7 +51,8 @@ class ComfyConfig(BaseModel):
     models: List[Model] = Field(default_factory=list, alias="Models")
     includes: List[str] = Field(default_factory=list)
     web: Optional[str] = None
+    banner_url: str = ""
 
 
 class License(BaseModel):
     file: str = ""
@@ -66,6 +66,10 @@ class ProjectConfig(BaseModel):
     dependencies: List[str] = Field(default_factory=list)
     license: License = Field(default_factory=License)
     urls: URLs = Field(default_factory=URLs)
+    supported_os: List[str] = Field(default_factory=list)
+    supported_accelerators: List[str] = Field(default_factory=list)
+    supported_comfyui_version: str = ""
+    supported_comfyui_frontend_version: str = ""
 
     @field_validator('license', mode='before')
     @classmethod
@@ -18,6 +18,7 @@ from comfy.execution_context import current_execution_context
 from comfy.nodes.base_nodes import ImageScale
 from comfy.nodes.common import MAX_RESOLUTION
 from comfy.nodes.package_typing import CustomNode
+from comfy.utils import common_upscale
 from comfy_extras.constants.resolutions import RESOLUTION_MAP, RESOLUTION_NAMES, SD_RESOLUTIONS
 
 
@@ -450,86 +451,235 @@ class ImageStitch:
                 "image1": ("IMAGE",),
                 "direction": (["right", "down", "left", "up"], {"default": "right"}),
                 "match_image_size": ("BOOLEAN", {"default": True}),
-                "spacing_width": ("INT", {"default": 0, "min": 0, "max": 1024, "step": 2},),
-                "spacing_color": (["white", "black", "red", "green", "blue"], {"default": "white"},),
+                "spacing_width": (
+                    "INT",
+                    {"default": 0, "min": 0, "max": 1024, "step": 2},
+                ),
+                "spacing_color": (
+                    ["white", "black", "red", "green", "blue"],
+                    {"default": "white"},
+                ),
+            },
+            "optional": {
+                "image2": ("IMAGE",),
             },
-            "optional": {"image2": ("IMAGE",), },
         }
 
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "stitch"
     CATEGORY = "image/transform"
-    DESCRIPTION = "Stitches image2 to image1 in the specified direction."
+    DESCRIPTION = """
+Stitches image2 to image1 in the specified direction.
+If image2 is not provided, returns image1 unchanged.
+Optional spacing can be added between images.
+"""
 
-    def stitch(self, image1, direction, match_image_size, spacing_width, spacing_color, image2=None):
+    def stitch(
+        self,
+        image1,
+        direction,
+        match_image_size,
+        spacing_width,
+        spacing_color,
+        image2=None,
+    ):
         if image2 is None:
             return (image1,)
 
+        # Handle batch size differences
         if image1.shape[0] != image2.shape[0]:
             max_batch = max(image1.shape[0], image2.shape[0])
             if image1.shape[0] < max_batch:
-                image1 = torch.cat([image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)])
+                image1 = torch.cat(
+                    [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
+                )
             if image2.shape[0] < max_batch:
-                image2 = torch.cat([image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)])
+                image2 = torch.cat(
+                    [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
+                )
 
+        # Match image sizes if requested
         if match_image_size:
             h1, w1 = image1.shape[1:3]
             h2, w2 = image2.shape[1:3]
             aspect_ratio = w2 / h2
 
             if direction in ["left", "right"]:
                 target_h, target_w = h1, int(h1 * aspect_ratio)
-            else:
+            else:  # up, down
                 target_w, target_h = w1, int(w1 / aspect_ratio)
-            image2 = utils.common_upscale(image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled").movedim(1, -1)
-        else:
+
+            image2 = common_upscale(
+                image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
+            ).movedim(1, -1)
+
+        color_map = {
+            "white": 1.0,
+            "black": 0.0,
+            "red": (1.0, 0.0, 0.0),
+            "green": (0.0, 1.0, 0.0),
+            "blue": (0.0, 0.0, 1.0),
+        }
+
+        color_val = color_map[spacing_color]
+
+        # When not matching sizes, pad to align non-concat dimensions
+        if not match_image_size:
             h1, w1 = image1.shape[1:3]
             h2, w2 = image2.shape[1:3]
+            pad_value = 0.0
+            if not isinstance(color_val, tuple):
+                pad_value = color_val
+
             if direction in ["left", "right"]:
+                # For horizontal concat, pad heights to match
                 if h1 != h2:
                     target_h = max(h1, h2)
                     if h1 < target_h:
                         pad_h = target_h - h1
                         pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
-                        image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
+                        image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
                     if h2 < target_h:
                         pad_h = target_h - h2
                         pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
-                        image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
-            else:
+                        image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
+            else:  # up, down
+                # For vertical concat, pad widths to match
                 if w1 != w2:
                     target_w = max(w1, w2)
                     if w1 < target_w:
                         pad_w = target_w - w1
                         pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
-                        image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
+                        image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
                     if w2 < target_w:
                         pad_w = target_w - w2
                         pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
-                        image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
+                        image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
 
-        images_to_stitch = [image2, image1] if direction in ["left", "up"] else [image1, image2]
+        # Ensure same number of channels
+        if image1.shape[-1] != image2.shape[-1]:
+            max_channels = max(image1.shape[-1], image2.shape[-1])
+            if image1.shape[-1] < max_channels:
+                image1 = torch.cat(
+                    [
+                        image1,
+                        torch.ones(
+                            *image1.shape[:-1],
+                            max_channels - image1.shape[-1],
+                            device=image1.device,
+                        ),
+                    ],
+                    dim=-1,
+                )
+            if image2.shape[-1] < max_channels:
+                image2 = torch.cat(
+                    [
+                        image2,
+                        torch.ones(
+                            *image2.shape[:-1],
+                            max_channels - image2.shape[-1],
+                            device=image2.device,
+                        ),
+                    ],
+                    dim=-1,
+                )
 
+        # Add spacing if specified
         if spacing_width > 0:
-            color_map = {"white": 1.0, "black": 0.0, "red": (1.0, 0.0, 0.0), "green": (0.0, 1.0, 0.0), "blue": (0.0, 0.0, 1.0), }
-            color_val = color_map[spacing_color]
+            spacing_width = spacing_width + (spacing_width % 2)  # Ensure even
+
             if direction in ["left", "right"]:
-                spacing_shape = (image1.shape[0], max(image1.shape[1], image2.shape[1]), spacing_width, image1.shape[-1],)
+                spacing_shape = (
+                    image1.shape[0],
+                    max(image1.shape[1], image2.shape[1]),
+                    spacing_width,
+                    image1.shape[-1],
+                )
             else:
-                spacing_shape = (image1.shape[0], spacing_width, max(image1.shape[2], image2.shape[2]), image1.shape[-1],)
+                spacing_shape = (
+                    image1.shape[0],
+                    spacing_width,
+                    max(image1.shape[2], image2.shape[2]),
+                    image1.shape[-1],
+                )
+
             spacing = torch.full(spacing_shape, 0.0, device=image1.device)
             if isinstance(color_val, tuple):
                 for i, c in enumerate(color_val):
                     if i < spacing.shape[-1]:
                         spacing[..., i] = c
-                if spacing.shape[-1] == 4:
+                if spacing.shape[-1] == 4:  # Add alpha
                     spacing[..., 3] = 1.0
             else:
                 spacing[..., : min(3, spacing.shape[-1])] = color_val
                 if spacing.shape[-1] == 4:
                     spacing[..., 3] = 1.0
-            images_to_stitch.insert(1, spacing)
 
+        # Concatenate images
+        images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
+        if spacing_width > 0:
+            images.insert(1, spacing)
+
         concat_dim = 2 if direction in ["left", "right"] else 1
-        return (torch.cat(images_to_stitch, dim=concat_dim),)
+        return (torch.cat(images, dim=concat_dim),)
+
+
+class ResizeAndPadImage:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "target_width": ("INT", {
+                    "default": 512,
+                    "min": 1,
+                    "max": MAX_RESOLUTION,
+                    "step": 1
+                }),
+                "target_height": ("INT", {
+                    "default": 512,
+                    "min": 1,
+                    "max": MAX_RESOLUTION,
+                    "step": 1
+                }),
+                "padding_color": (["white", "black"],),
+                "interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "resize_and_pad"
+    CATEGORY = "image/transform"
+
+    def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
+        batch_size, orig_height, orig_width, channels = image.shape
+
+        scale_w = target_width / orig_width
+        scale_h = target_height / orig_height
+        scale = min(scale_w, scale_h)
+
+        new_width = int(orig_width * scale)
+        new_height = int(orig_height * scale)
+
+        image_permuted = image.permute(0, 3, 1, 2)
+
+        resized = common_upscale(image_permuted, new_width, new_height, interpolation, "disabled")
+
+        pad_value = 0.0 if padding_color == "black" else 1.0
+        padded = torch.full(
+            (batch_size, channels, target_height, target_width),
+            pad_value,
+            dtype=image.dtype,
+            device=image.device
+        )
+
+        y_offset = (target_height - new_height) // 2
+        x_offset = (target_width - new_width) // 2
+
+        padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
+
+        output = padded.permute(0, 2, 3, 1)
+        return (output,)
 
 
 class SaveSVGNode:
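
Reviewer note (not part of the diff): the geometry in resize_and_pad is plain aspect-preserving letterboxing: scale by the tighter of the two ratios, then center the result on the padded canvas. A standalone check of that arithmetic with illustrative sizes:

    def letterbox_geometry(orig_w, orig_h, target_w, target_h):
        # Same arithmetic as resize_and_pad.
        scale = min(target_w / orig_w, target_h / orig_h)
        new_w, new_h = int(orig_w * scale), int(orig_h * scale)
        return new_w, new_h, (target_w - new_w) // 2, (target_h - new_h) // 2

    # A 1024x768 image into a 512x512 canvas: width-limited, vertically centered.
    print(letterbox_geometry(1024, 768, 512, 512))  # (512, 384, 0, 64)
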
@@ -599,6 +749,7 @@ NODE_CLASS_MAPPINGS = {
     "SaveSVGNode": SaveSVGNode,
     "ImageStitch": ImageStitch,
     "GetImageSize": GetImageSize,
+    "ResizeAndPadImage": ResizeAndPadImage,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -276,6 +276,52 @@ class ModelMergeWAN2_1(nodes_model_merging.ModelMergeBlocks):
 
         return {"required": arg_dict}
 
+
+class ModelMergeCosmosPredict2_2B(nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embedder."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+        arg_dict["t_embedding_norm."] = argument
+
+        for i in range(28):
+            arg_dict["blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
+
+class ModelMergeCosmosPredict2_14B(nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["pos_embedder."] = argument
+        arg_dict["x_embedder."] = argument
+        arg_dict["t_embedder."] = argument
+        arg_dict["t_embedding_norm."] = argument
+
+        for i in range(36):
+            arg_dict["blocks.{}.".format(i)] = argument
+
+        arg_dict["final_layer."] = argument
+
+        return {"required": arg_dict}
+
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSD1": ModelMergeSD1,
     "ModelMergeSD2": ModelMergeSD1, # SD1 and SD2 have the same blocks
@@ -289,4 +335,6 @@ NODE_CLASS_MAPPINGS = {
     "ModelMergeCosmos7B": ModelMergeCosmos7B,
     "ModelMergeCosmos14B": ModelMergeCosmos14B,
     "ModelMergeWAN2_1": ModelMergeWAN2_1,
+    "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B,
+    "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B,
 }