Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-20 19:42:59 +08:00)

Merge branch 'comfyanonymous:master' into feature/custom-node-paths-cli-args

Commit d3d75415ff
@@ -66,8 +66,10 @@ if branch is None:
     try:
        ref = repo.lookup_reference('refs/remotes/origin/master')
     except:
-        print("pulling.") # noqa: T201
-        pull(repo)
+        print("fetching.") # noqa: T201
+        for remote in repo.remotes:
+            if remote.name == "origin":
+                remote.fetch()
        ref = repo.lookup_reference('refs/remotes/origin/master')
     repo.checkout(ref)
     branch = repo.lookup_branch('master')
@@ -149,3 +151,4 @@ try:
     shutil.copy(stable_update_script, stable_update_script_to)
 except:
     pass
+
@@ -1,3 +1,2 @@
 # Admins
-* @comfyanonymous
-* @kosinkadink
+* @comfyanonymous @kosinkadink @guill
README.md (27 changed lines)
@@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
 - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
 - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
+- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
 - Audio Models
   - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
   - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -319,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
 1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
 2. Launch ComfyUI by running `python main.py`
 
 
+## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
+
+**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
+
+### Setup
+
+1. Install the manager dependencies:
+   ```bash
+   pip install -r manager_requirements.txt
+   ```
+
+2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
+   ```bash
+   python main.py --enable-manager
+   ```
+
+### Command Line Options
+
+| Flag | Description |
+|------|-------------|
+| `--enable-manager` | Enable ComfyUI-Manager |
+| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
+| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
+
+
 # Running
 
 ```python main.py```
@@ -122,6 +122,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
 upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
 
 
+parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
+manager_group = parser.add_mutually_exclusive_group()
+manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
+manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
+
+
 vram_group = parser.add_mutually_exclusive_group()
 vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
 vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
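The two UI flags above live in an argparse mutually exclusive group, so passing both fails at parse time rather than silently picking one. A standalone sketch (not ComfyUI's actual parser) of that behavior:

```python
import argparse

# Minimal sketch of the flag layout in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument("--enable-manager", action="store_true")
manager_group = parser.add_mutually_exclusive_group()
manager_group.add_argument("--disable-manager-ui", action="store_true")
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true")

print(parser.parse_args(["--enable-manager", "--enable-manager-legacy-ui"]))
# Combining the two grouped flags exits with an error instead:
# parser.parse_args(["--disable-manager-ui", "--enable-manager-legacy-ui"])
# -> "argument --enable-manager-legacy-ui: not allowed with argument --disable-manager-ui"
```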
@@ -169,6 +175,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
 parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
 parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
 
+
 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
 
@@ -51,26 +51,36 @@ class ContextHandlerABC(ABC):
 
 
 class IndexListContextWindow(ContextWindowABC):
-    def __init__(self, index_list: list[int], dim: int=0):
+    def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
         self.index_list = index_list
         self.context_length = len(index_list)
         self.dim = dim
+        self.total_frames = total_frames
+        self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
 
-    def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
+    def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
         if dim is None:
             dim = self.dim
         if dim == 0 and full.shape[dim] == 1:
             return full
-        idx = [slice(None)] * dim + [self.index_list]
-        return full[idx].to(device)
+        idx = tuple([slice(None)] * dim + [self.index_list])
+        window = full[idx]
+        if retain_index_list:
+            idx = tuple([slice(None)] * dim + [retain_index_list])
+            window[idx] = full[idx]
+        return window.to(device)
 
     def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
         if dim is None:
             dim = self.dim
-        idx = [slice(None)] * dim + [self.index_list]
+        idx = tuple([slice(None)] * dim + [self.index_list])
         full[idx] += to_add
         return full
 
+    def get_region_index(self, num_regions: int) -> int:
+        region_idx = int(self.center_ratio * num_regions)
+        return min(max(region_idx, 0), num_regions - 1)
+
 
 class IndexListCallbacks:
     EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
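The list-to-tuple change above matters because PyTorch treats a plain Python list indexer as advanced indexing along the first dimension, while a tuple means one indexer per dimension. A standalone sketch of the pattern, assuming dim=2 (e.g. B, C, T, H, W video latents); the `retain` values are hypothetical:

```python
import torch

full = torch.arange(2 * 4 * 8).reshape(2, 4, 8).float()
index_list = [2, 3, 4]
dim = 2

idx = tuple([slice(None)] * dim + [index_list])  # == (slice(None), slice(None), [2, 3, 4])
window = full[idx]                               # frames 2..4 along dim 2
print(window.shape)                              # torch.Size([2, 4, 3])

# retain_index_list re-copies selected frames from the full tensor into the
# window, mirroring get_tensor(..., retain_index_list=...) above:
retain = [0]
ridx = tuple([slice(None)] * dim + [retain])
window[ridx] = full[ridx]
```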
@@ -94,7 +104,8 @@ class ContextFuseMethod:
 
 ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
 class IndexListContextHandler(ContextHandlerABC):
-    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
+    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
+                 closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
         self.context_schedule = context_schedule
         self.fuse_method = fuse_method
         self.context_length = context_length
@@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC):
         self.closed_loop = closed_loop
         self.dim = dim
         self._step = 0
+        self.freenoise = freenoise
+        self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
+        self.split_conds_to_windows = split_conds_to_windows
 
         self.callbacks = {}
 
     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
         # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
         if x_in.size(self.dim) > self.context_length:
-            logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
+            logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
+            if self.cond_retain_index_list:
+                logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
             return True
         return False
 
@@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC):
             return None
         # reuse or resize cond items to match context requirements
         resized_cond = []
+        # if multiple conds, split based on primary region
+        if self.split_conds_to_windows and len(cond_in) > 1:
+            region = window.get_region_index(len(cond_in))
+            logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
+            cond_in = [cond_in[region]]
         # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
         for actual_cond in cond_in:
             resized_actual_cond = actual_cond.copy()
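The region split above maps each window to one of N prompts by its temporal center. A standalone sketch of that mapping, mirroring center_ratio / get_region_index from the earlier hunk (frame counts are hypothetical):

```python
# 81 frames, 3 prompts: early/middle/late windows land in regions 0/1/2.
def get_region_index(index_list, total_frames, num_regions):
    center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
    region_idx = int(center_ratio * num_regions)
    return min(max(region_idx, 0), num_regions - 1)

print(get_region_index(range(0, 16), 81, 3))   # 0
print(get_region_index(range(32, 48), 81, 3))  # 1
print(get_region_index(range(64, 80), 81, 3))  # 2
```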
@@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC):
                 # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
                 for cond_key, cond_value in new_cond_item.items():
                     if isinstance(cond_value, torch.Tensor):
-                        if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
+                        if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
+                                (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
                             new_cond_item[cond_key] = window.get_tensor(cond_value, device)
+                    # Handle audio_embed (temporal dim is 1)
+                    elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+                        audio_cond = cond_value.cond
+                        if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
+                            new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
                     # if has cond that is a Tensor, check if needs to be subset
                     elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
-                        if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
-                            new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
+                        if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
+                                (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
+                            new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
                     elif cond_key == "num_video_frames": # for SVD
                         new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
                         new_cond_item[cond_key].cond = window.context_length
@@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC):
         return resized_cond
 
     def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
-        mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
+        mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
        matches = torch.nonzero(mask)
        if torch.numel(matches) == 0:
            raise Exception("No sample_sigmas matched current timestep; something went wrong.")
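The set_step change compares a scalar `timestep[0]` against the sigma schedule; a batched timestep tensor would not broadcast against the schedule. A standalone sketch of the lookup with hypothetical sigma values:

```python
import torch

# Find which entry of the sampler's sigma schedule matches the current
# (batched) timestep, as in set_step above.
sample_sigmas = torch.tensor([14.61, 7.82, 3.34, 1.02, 0.03, 0.0])
timestep = torch.full((4,), 3.34)  # one sigma per batch element

mask = torch.isclose(sample_sigmas, timestep[0], rtol=0.0001)  # scalar vs schedule
matches = torch.nonzero(mask)
assert torch.numel(matches) > 0
print(matches[0].item())  # 2
```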
@@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC):
     def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
         full_length = x_in.size(self.dim) # TODO: choose dim based on model
         context_windows = self.context_schedule.func(full_length, self, model_options)
-        context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
+        context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
         return context_windows
 
     def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
@@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC):
                     prev_weight = (bias_total / (bias_total + bias))
                     new_weight = (bias / (bias_total + bias))
                     # account for dims of tensors
-                    idx_window = [slice(None)] * self.dim + [idx]
-                    pos_window = [slice(None)] * self.dim + [pos]
+                    idx_window = tuple([slice(None)] * self.dim + [idx])
+                    pos_window = tuple([slice(None)] * self.dim + [pos])
                     # apply new values
                     conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
                     biases_final[i][idx] = bias_total + bias
@@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
     )
 
 
+def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
+    model_options = extra_args.get("model_options", None)
+    if model_options is None:
+        raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+    handler: IndexListContextHandler = model_options.get("context_handler", None)
+    if handler is None:
+        raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+    if not handler.freenoise:
+        return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+    noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
+
+    return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+
+
+def create_sampler_sample_wrapper(model: ModelPatcher):
+    model.add_wrapper_with_key(
+        comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
+        "ContextWindows_sampler_sample",
+        _sampler_sample_wrapper
+    )
+
+
 def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
     total_dims = len(x_in.shape)
     weights_tensor = torch.Tensor(weights).to(device=device)
@@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
     for i in range(len(window)):
         # 2) add end_delta to each val to slide windows to end
         window[i] = window[i] + end_delta
+
+
+# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
+def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
+    logging.info("Context windows: Applying FreeNoise")
+    generator = torch.Generator(device='cpu').manual_seed(seed)
+    latent_video_length = noise.shape[dim]
+    delta = context_length - context_overlap
+
+    for start_idx in range(0, latent_video_length - context_length, delta):
+        place_idx = start_idx + context_length
+
+        actual_delta = min(delta, latent_video_length - place_idx)
+        if actual_delta <= 0:
+            break
+
+        list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
+
+        source_slice = [slice(None)] * noise.ndim
+        source_slice[dim] = list_idx
+        target_slice = [slice(None)] * noise.ndim
+        target_slice[dim] = slice(place_idx, place_idx + actual_delta)
+
+        noise[tuple(target_slice)] = noise[tuple(source_slice)]
+
+    return noise
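FreeNoise fills frames beyond the first window with shuffled copies of earlier noise frames, so overlapping context windows denoise against correlated noise instead of independent noise. A usage sketch, assuming the apply_freenoise helper from the hunk above is in scope (tensor sizes are hypothetical):

```python
import torch

noise = torch.randn(1, 16, 81, 8, 8)  # B, C, T, H, W; dim=2 is the frame axis

shuffled = apply_freenoise(noise.clone(), 2, 16, 4, 42)  # dim, length, overlap, seed
# Frames past the first 16 are now permutations of earlier frames.
print(shuffled.shape)  # torch.Size([1, 16, 81, 8, 8])
```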
@@ -40,7 +40,8 @@ class ChromaParams:
     out_dim: int
     hidden_dim: int
     n_layers: int
+    txt_ids_dims: list
+    vec_in_dim: int
 
 
 
@@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
 
+class YakMLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+    if yak_mlp:
+        return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+    if mlp_silu_act:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+            SiLUActivation(),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+        )
+    else:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):
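YakMLP is a SwiGLU-style gated MLP: the activation is applied to a gate branch and multiplied elementwise with a linear "up" branch before projecting back down, unlike the plain GELU Sequential. A standalone sketch using plain nn.Linear in place of ComfyUI's `operations` wrappers:

```python
import torch
import torch.nn as nn

class GatedMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        # Gating: activate the gate branch, multiply with the up branch,
        # then project back to hidden_size.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

x = torch.randn(2, 16, 64)
print(GatedMLP(64, 256)(x).shape)  # torch.Size([2, 16, 64])
```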
@@ -140,7 +169,7 @@ class SiLUActivation(nn.Module):
 
 
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module):
 
         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
 
-        if mlp_silu_act:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module):
 
         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
 
-        if mlp_silu_act:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
         self.flipped_img_txt = flipped_img_txt
 
@@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module):
         modulation=True,
         mlp_silu_act=False,
         bias=True,
+        yak_mlp=False,
         dtype=None,
         device=None,
         operations=None
@@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module):
         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
 
         self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        self.yak_mlp = yak_mlp
         if mlp_silu_act:
             self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
             self.mlp_act = SiLUActivation()
         else:
             self.mlp_act = nn.GELU(approximate="tanh")
 
+        if self.yak_mlp:
+            self.mlp_hidden_dim_first *= 2
+            self.mlp_act = nn.SiLU()
+
         # qkv and mlp_in
         self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out
@@ -325,6 +338,9 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        mlp = self.mlp_act(mlp)
+        if self.yak_mlp:
+            mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+        else:
+            mlp = self.mlp_act(mlp)
         output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
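In SingleStreamBlock the qkv and MLP input projections are fused into linear1, so with yak_mlp the MLP slice of its output is doubled and then split into a gate half and a value half. A standalone sketch of that split with hypothetical sizes:

```python
import torch
import torch.nn as nn

hidden_size, mlp_hidden_dim_first = 64, 512  # 512 = 2 * int(64 * 4.0)
mlp = torch.randn(2, 16, mlp_hidden_dim_first)  # mlp slice of linear1's output
act = nn.SiLU()

# Activate one half as the gate, multiply with the other half as the value.
gated = act(mlp[..., mlp_hidden_dim_first // 2:]) * mlp[..., :mlp_hidden_dim_first // 2]
print(gated.shape)  # torch.Size([2, 16, 256]); linear2 then maps attn+gated back down
```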
@@ -15,7 +15,8 @@ from .layers import (
     MLPEmbedder,
     SingleStreamBlock,
     timestep_embedding,
-    Modulation
+    Modulation,
+    RMSNorm
 )
 
 @dataclass
@@ -34,11 +35,14 @@ class FluxParams:
     patch_size: int
     qkv_bias: bool
     guidance_embed: bool
+    txt_ids_dims: list
     global_modulation: bool = False
     mlp_silu_act: bool = False
     ops_bias: bool = True
     default_ref_method: str = "offset"
     ref_index_scale: float = 1.0
+    yak_mlp: bool = False
+    txt_norm: bool = False
 
 
 class Flux(nn.Module):
@@ -76,6 +80,11 @@ class Flux(nn.Module):
         )
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
 
+        if params.txt_norm:
+            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+        else:
+            self.txt_norm = None
+
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(
@@ -86,6 +95,7 @@ class Flux(nn.Module):
                     modulation=params.global_modulation is False,
                     mlp_silu_act=params.mlp_silu_act,
                     proj_bias=params.ops_bias,
+                    yak_mlp=params.yak_mlp,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -94,7 +104,7 @@ class Flux(nn.Module):
 
         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )
@@ -150,6 +160,8 @@ class Flux(nn.Module):
             y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
         vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
 
+        if self.txt_norm is not None:
+            txt = self.txt_norm(txt)
         txt = self.txt_in(txt)
 
         vec_orig = vec
@@ -332,8 +344,9 @@ class Flux(nn.Module):
 
         txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
 
-        if len(self.params.axes_dim) == 4: # Flux 2
-            txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+        if len(self.params.txt_ids_dims) > 0:
+            for i in self.params.txt_ids_dims:
+                txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
 
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
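The txt_ids_dims field generalizes the old hard-coded "Flux 2" case: every axis listed in it gets a 0..N-1 position ramp over the text tokens. A standalone sketch with hypothetical sizes (four RoPE axes, as in the old special case):

```python
import torch

bs, n_txt, n_axes = 2, 77, 4
txt_ids_dims = [3]  # the old behavior hard-coded axis 3 when len(axes_dim) == 4

txt_ids = torch.zeros((bs, n_txt, n_axes), dtype=torch.float32)
for i in txt_ids_dims:
    txt_ids[:, :, i] = torch.linspace(0, n_txt - 1, steps=n_txt)
print(txt_ids[0, :3, 3])  # tensor([0., 1., 2.])
```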
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
+from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
 import model_management, model_patcher
 
 class SRResidualCausalBlock3D(nn.Module):
@@ -1,42 +1,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
 import comfy.ops
 import comfy.ldm.models.autoencoder
 import comfy.model_management
 ops = comfy.ops.disable_weight_init
 
-class NoPadConv3d(nn.Module):
-    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
-        super().__init__()
-        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
-    x = xl[0]
-    xl.clear()
-
-    if conv_carry_out is not None:
-        to_push = x[:, :, -2:, :, :].clone()
-        conv_carry_out.append(to_push)
-
-    if isinstance(op, NoPadConv3d):
-        if conv_carry_in is None:
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
-        else:
-            carry_len = conv_carry_in[0].shape[2]
-            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
-
-    out = op(x)
-
-    return out
-
-
 class RMS_norm(nn.Module):
     def __init__(self, dim):
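The removed conv_carry_causal_3d (now provided by CarriedConv3d / conv_carry_causal_3d in comfy.ldm.modules.diffusionmodules.model) shows the carry trick for chunked causal 3D convolution: keep the last two frames of each chunk and prepend them to the next chunk so chunk-by-chunk processing matches a single causally padded pass. A standalone sketch of just the padding/carry step:

```python
import torch
import torch.nn.functional as F

def causal_pad_chunk(x, carry):  # x: B, C, T, H, W
    new_carry = x[:, :, -2:, :, :].clone()  # carried into the next chunk
    if carry is None:
        # first chunk: replicate-pad 2 frames at the front (causal), 1 spatially
        x = F.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate')
    else:
        carry_len = carry.shape[2]
        x = torch.cat([carry, x], dim=2)  # prepend frames from the previous chunk
        x = F.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode='replicate')
    return x, new_carry

chunk = torch.randn(1, 4, 8, 16, 16)
padded, carry = causal_pad_chunk(chunk, None)
print(padded.shape)  # torch.Size([1, 4, 10, 18, 18])
```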
@@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
         return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
 
 class DnSmpl(nn.Module):
-    def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
+    def __init__(self, ic, oc, tds, refiner_vae, op):
         super().__init__()
         fct = 2 * 2 * 2 if tds else 1 * 2 * 2
         assert oc % fct == 0
@@ -109,7 +79,7 @@ class DnSmpl(nn.Module):
 
 
 class UpSmpl(nn.Module):
-    def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
+    def __init__(self, ic, oc, tus, refiner_vae, op):
         super().__init__()
         fct = 2 * 2 * 2 if tus else 1 * 2 * 2
         self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
 
         return h + x
 
-class HunyuanRefinerResnetBlock(ResnetBlock):
-    def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
-        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-
-    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
-        h = x
-        h = [ self.swish(self.norm1(x)) ]
-        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
-        h = [ self.dropout(self.swish(self.norm2(h))) ]
-        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
-        if self.in_channels != self.out_channels:
-            x = self.nin_shortcut(x)
-
-        return x+h
-
 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                  ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@@ -191,7 +144,7 @@ class Encoder(nn.Module):
 
         self.refiner_vae = refiner_vae
         if self.refiner_vae:
-            conv_op = NoPadConv3d
+            conv_op = CarriedConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -206,8 +159,9 @@ class Encoder(nn.Module):
 
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
+                                                     temb_channels=0,
                                                      conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks)])
             ch = tgt
@@ -218,9 +172,9 @@ class Encoder(nn.Module):
             self.down.append(stage)
 
         self.mid = nn.Module()
-        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
 
         self.norm_out = norm_op(ch)
         self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -246,22 +200,20 @@ class Encoder(nn.Module):
                 conv_carry_out = []
                 if i == len(x) - 1:
                     conv_carry_out = None
 
                x1 = [ x1 ]
                x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
 
                for stage in self.down:
                    for blk in stage.block:
-                        x1 = blk(x1, conv_carry_in, conv_carry_out)
+                        x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                    if hasattr(stage, 'downsample'):
                        x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
 
                out.append(x1)
                conv_carry_in = conv_carry_out
 
-            if len(out) > 1:
-                out = torch.cat(out, dim=2)
-            else:
-                out = out[0]
+            out = torch_cat_if_needed(out, dim=2)
 
             x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
             del out
@@ -288,7 +240,7 @@ class Decoder(nn.Module):
 
         self.refiner_vae = refiner_vae
         if self.refiner_vae:
-            conv_op = NoPadConv3d
+            conv_op = CarriedConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -298,9 +250,9 @@ class Decoder(nn.Module):
         self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
 
         self.mid = nn.Module()
-        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
 
         self.up = nn.ModuleList()
         depth = (ffactor_spatial >> 1).bit_length()
@@ -308,8 +260,9 @@ class Decoder(nn.Module):
 
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
+                                                     temb_channels=0,
                                                      conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks + 1)])
             ch = tgt
@@ -340,7 +293,7 @@ class Decoder(nn.Module):
                conv_carry_out = None
                for stage in self.up:
                    for blk in stage.block:
-                        x1 = blk(x1, conv_carry_in, conv_carry_out)
+                        x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                    if hasattr(stage, 'upsample'):
                        x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
 
@@ -350,10 +303,7 @@ class Decoder(nn.Module):
                conv_carry_in = conv_carry_out
            del x
 
-            if len(out) > 1:
-                out = torch.cat(out, dim=2)
-            else:
-                out = out[0]
+            out = torch_cat_if_needed(out, dim=2)
 
        if not self.refiner_vae:
            if z.shape[-3] == 1:
comfy/ldm/kandinsky5/model.py (new file, 413 lines)
@@ -0,0 +1,413 @@
+import torch
+from torch import nn
+import math
+
+import comfy.ldm.common_dit
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.math import apply_rope1
+from comfy.ldm.flux.layers import EmbedND
+
+def attention(q, k, v, heads, transformer_options={}):
+    return optimized_attention(
+        q.transpose(1, 2),
+        k.transpose(1, 2),
+        v.transpose(1, 2),
+        heads=heads,
+        skip_reshape=True,
+        transformer_options=transformer_options
+    )
+
+def apply_scale_shift_norm(norm, x, scale, shift):
+    return torch.addcmul(shift, norm(x), scale + 1.0)
+
+def apply_gate_sum(x, out, gate):
+    return torch.addcmul(x, gate, out)
+
+def get_shift_scale_gate(params):
+    shift, scale, gate = torch.chunk(params, 3, dim=-1)
+    return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
+
+def get_freqs(dim, max_period=10000.0):
+    return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
+
+
+class TimeEmbeddings(nn.Module):
+    def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
+        super().__init__()
+        assert model_dim % 2 == 0
+        self.model_dim = model_dim
+        self.max_period = max_period
+        self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
+        operations = operation_settings.get("operations")
+        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.activation = nn.SiLU()
+        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, timestep, dtype):
+        args = torch.outer(timestep, self.freqs.to(device=timestep.device))
+        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
+        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
+        return time_embed
+
+
+class TextEmbeddings(nn.Module):
+    def __init__(self, text_dim, model_dim, operation_settings=None):
+        super().__init__()
+        operations = operation_settings.get("operations")
+        self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, text_embed):
+        text_embed = self.in_layer(text_embed)
+        return self.norm(text_embed).type_as(text_embed)
+
+
+class VisualEmbeddings(nn.Module):
+    def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
+        super().__init__()
+        self.patch_size = patch_size
+        operations = operation_settings.get("operations")
+        self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, x):
+        x = x.movedim(1, -1) # B C T H W -> B T H W C
+        B, T, H, W, dim = x.shape
+        pt, ph, pw = self.patch_size
+
+        x = x.view(
+            B,
+            T // pt, pt,
+            H // ph, ph,
+            W // pw, pw,
+            dim,
+        ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
+
+        return self.in_layer(x)
+
+
+class Modulation(nn.Module):
+    def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
+        super().__init__()
+        self.activation = nn.SiLU()
+        self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, x):
+        return self.out_layer(self.activation(x))
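TimeEmbeddings above is the standard sinusoidal timestep embedding followed by a two-layer MLP. A standalone sketch of just the sinusoidal part, using the same get_freqs formula (sizes hypothetical):

```python
import math
import torch

def get_freqs(dim, max_period=10000.0):
    return torch.exp(-math.log(max_period) * torch.arange(0, dim, dtype=torch.float32) / dim)

model_dim = 128
freqs = get_freqs(model_dim // 2)
timestep = torch.tensor([0.0, 250.0, 999.0])

args = torch.outer(timestep, freqs)                             # (3, 64)
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
print(time_embed.shape)  # torch.Size([3, 128])
```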
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(self, num_channels, head_dim, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
assert num_channels % head_dim == 0
|
||||||
|
self.num_heads = num_channels // head_dim
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.num_chunks = 2
|
||||||
|
|
||||||
|
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
|
||||||
|
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
|
return apply_rope1(norm_fn(result), freqs)
|
||||||
|
|
||||||
|
def _forward(self, x, freqs, transformer_options={}):
|
||||||
|
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||||
|
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||||
|
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
|
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||||
|
return self.out_layer(out)
|
||||||
|
|
||||||
|
def _forward_chunked(self, x, freqs, transformer_options={}):
|
||||||
|
def process_chunks(proj_fn, norm_fn):
|
||||||
|
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||||
|
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||||
|
chunks = []
|
||||||
|
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
|
||||||
|
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
|
||||||
|
return torch.cat(chunks, dim=1)
|
||||||
|
|
||||||
|
q = process_chunks(self.to_query, self.query_norm)
|
||||||
|
k = process_chunks(self.to_key, self.key_norm)
|
||||||
|
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
|
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||||
|
return self.out_layer(out)
|
||||||
|
|
||||||
|
def forward(self, x, freqs, transformer_options={}):
|
||||||
|
if x.shape[1] > 8192:
|
||||||
|
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
||||||
|
else:
|
||||||
|
return self._forward(x, freqs, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(SelfAttention):
|
||||||
|
def get_qkv(self, x, context):
|
||||||
|
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
|
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||||
|
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
def forward(self, x, context, transformer_options={}):
|
||||||
|
q, k, v = self.get_qkv(x, context)
|
||||||
|
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
|
||||||
|
return self.out_layer(out)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, ff_dim, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.activation = nn.GELU()
|
||||||
|
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.num_chunks = 4
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
return self.out_layer(self.activation(self.in_layer(x)))
|
||||||
|
|
||||||
|
def _forward_chunked(self, x):
|
||||||
|
chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||||
|
output_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
output_chunks.append(self._forward(chunk))
|
||||||
|
return torch.cat(output_chunks, dim=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if x.shape[1] > 8192:
|
||||||
|
return self._forward_chunked(x)
|
||||||
|
else:
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class OutLayer(nn.Module):
    def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
        super().__init__()
        self.patch_size = patch_size
        self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
        operations = operation_settings.get("operations")
        self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, visual_embed, time_embed):
        B, T, H, W, _ = visual_embed.shape
        shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
        scale = scale[:, None, None, None, :]
        shift = shift[:, None, None, None, :]
        visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
        x = self.out_layer(visual_embed)

        out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
        x = x.view(
            B, T, H, W,
            out_dim,
            self.patch_size[0], self.patch_size[1], self.patch_size[2]
        )
        return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
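The closing `view`/`permute`/`flatten` in `OutLayer.forward` is the inverse of patch embedding: every token carries a full `(pt, ph, pw)` patch of output channels, and the reshape interleaves those patches back into a dense `(B, C, T*pt, H*ph, W*pw)` latent. A quick shape check of just that step:

```python
import torch

B, T, H, W = 1, 4, 3, 5      # token grid
C, pt, ph, pw = 16, 1, 2, 2  # output channels and patch size

x = torch.randn(B, T, H, W, C * pt * ph * pw)
x = x.view(B, T, H, W, C, pt, ph, pw)
# (B, C, T, pt, H, ph, W, pw) -> merge each patch axis into its grid axis
x = x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
assert x.shape == (B, C, T * pt, H * ph, W * pw)
```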
class TransformerEncoderBlock(nn.Module):
    def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
        super().__init__()
        self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
        operations = operation_settings.get("operations")

        self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)

        self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

    def forward(self, x, time_embed, freqs, transformer_options={}):
        self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
        shift, scale, gate = get_shift_scale_gate(self_attn_params)
        out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
        out = self.self_attention(out, freqs, transformer_options=transformer_options)
        x = apply_gate_sum(x, out, gate)

        shift, scale, gate = get_shift_scale_gate(ff_params)
        out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
        out = self.feed_forward(out)
        x = apply_gate_sum(x, out, gate)
        return x


class TransformerDecoderBlock(nn.Module):
    def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
        super().__init__()
        self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)

        operations = operation_settings.get("operations")
        self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)

        self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)

        self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
        self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
        # self attention
        shift, scale, gate = get_shift_scale_gate(self_attn_params)
        visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
        visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
        # cross attention
        shift, scale, gate = get_shift_scale_gate(cross_attn_params)
        visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
        visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
        # feed forward
        shift, scale, gate = get_shift_scale_gate(ff_params)
        visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
        visual_out = self.feed_forward(visual_out)
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
        return visual_embed


class Kandinsky5(nn.Module):
    def __init__(
        self,
        in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
        model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
        axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
        dtype=None, device=None, operations=None, **kwargs
    ):
        super().__init__()
        head_dim = sum(axes_dims)
        self.rope_scale_factor = rope_scale_factor
        self.in_visual_dim = in_visual_dim
        self.model_dim = model_dim
        self.patch_size = patch_size
        self.visual_embed_dim = visual_embed_dim
        self.dtype = dtype
        self.device = device
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
        self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
        self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
        self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)

        self.text_transformer_blocks = nn.ModuleList(
            [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
        )

        self.visual_transformer_blocks = nn.ModuleList(
            [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
        )

        self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)

        self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
        self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])

    def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
        steps = seq_len if steps is None else steps
        seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
        seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0)  # Shape: (1, steps, 1)
        freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
        return freqs

    def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])

        if steps_t is None:
            steps_t = t_len
        if steps_h is None:
            steps_h = h_len
        if steps_w is None:
            steps_w = w_len

        h_start = 0
        w_start = 0
        rope_options = transformer_options.get("rope_options", None)
        if rope_options is not None:
            t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0

            t_start += rope_options.get("shift_t", 0.0)
            h_start += rope_options.get("shift_y", 0.0)
            w_start += rope_options.get("shift_x", 0.0)
        else:
            rope_scale_factor = self.rope_scale_factor
            if self.model_dim == 4096:  # pro video model uses different rope scaling at higher resolutions
                if h * w >= 14080:
                    rope_scale_factor = (1.0, 3.16, 3.16)

            t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
            h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
            w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0

        img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
        img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])

        freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
        return freqs

    def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
        patches_replace = transformer_options.get("patches_replace", {})
        context = self.text_embeddings(context)
        time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)

        for block in self.text_transformer_blocks:
            context = block(context, time_embed, freqs_text, transformer_options=transformer_options)

        visual_embed = self.visual_embeddings(x)
        visual_shape = visual_embed.shape[:-1]
        visual_embed = visual_embed.flatten(1, -2)

        blocks_replace = patches_replace.get("dit", {})
        transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
        transformer_options["block_type"] = "double"
        for i, block in enumerate(self.visual_transformer_blocks):
            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
                visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
            else:
                visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)

        visual_embed = visual_embed.reshape(*visual_shape, -1)
        return self.out_layer(visual_embed, time_embed)

    def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
        original_dims = x.ndim
        if original_dims == 4:
            x = x.unsqueeze(2)
        bs, c, t_len, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

        if time_dim_replace is not None:
            time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
            x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace

        freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
        freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)

        out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
        if original_dims == 4:
            out = out.squeeze(2)
        return out

    def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
        ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
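`rope_encode_3d` never feeds raw grid indices to the RoPE embedder: each axis length is first compressed by `rope_scale_factor` (with the pro video model switching to `(1.0, 3.16, 3.16)` once `h * w >= 14080`), and `torch.linspace` then spreads the same number of steps over the shorter range. The result is positional interpolation, fractional position ids that keep large inputs inside the trained positional range. The one-axis case, in isolation:

```python
import torch

def scaled_positions(n_tokens, scale):
    # Same number of steps, end point shrunk by `scale`, mirroring
    # t_len = (t_len - 1.0) / scale + 1.0 followed by linspace above.
    length = (n_tokens - 1.0) / scale + 1.0
    return torch.linspace(0, length - 1, steps=n_tokens)

print(scaled_positions(6, 1.0))  # tensor([0., 1., 2., 3., 4., 5.])
print(scaled_positions(6, 2.0))  # tensor([0.0000, 0.5000, 1.0000, 1.5000, 2.0000, 2.5000])
```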
113  comfy/ldm/lumina/controlnet.py  (new file)
@@ -0,0 +1,113 @@
import torch
from torch import nn

from .model import JointTransformerBlock


class ZImageControlTransformerBlock(JointTransformerBlock):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
        operation_settings=None,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            c = self.before_proj(c) + x
        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        return c_skip, c


class ZImage_Control(torch.nn.Module):
    def __init__(
        self,
        dim: int = 3840,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        multiple_of: int = 256,
        ffn_dim_multiplier: float = (8.0 / 3.0),
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        super().__init__()
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        self.additional_in_dim = 0
        self.control_in_dim = 16
        n_refiner_layers = 2
        self.n_control_layers = 6
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(
                    i,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    block_id=i,
                    operation_settings=operation_settings,
                )
                for i in range(self.n_control_layers)
            ]
        )

        all_x_embedder = {}
        patch_size = 2
        f_patch_size = 1
        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
        self.control_noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    z_image_modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
        patch_size = 2
        f_patch_size = 1
        pH = pW = patch_size
        B, C, H, W = control_context.shape
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))

        x_attn_mask = None
        for layer in self.control_noise_refiner:
            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
        return control_context

    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
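Each `ZImageControlTransformerBlock` returns a pair: `c_skip`, the projected residual meant for the host model, and `c`, the running control stream for the next control layer (block 0 additionally mixes the noisy latent in through `before_proj`). This diff does not show the host-side integration, so the wiring below is a hypothetical sketch; `base_blocks` and the 1:1 layer pairing are assumptions for illustration only:

```python
def run_with_control(base_blocks, control, x, cond, x_attn_mask, freqs_cis, adaln_input):
    # Hypothetical wiring, not the actual ComfyUI integration: pair each of the
    # first n_control_layers base blocks with a control block and add its skip.
    c = cond
    for i, block in enumerate(base_blocks):
        if i < control.n_control_layers:
            c_skip, c = control.forward_control_block(i, c, x, x_attn_mask, freqs_cis, adaln_input)
            x = x + c_skip  # inject the control residual into the hidden states
        x = block(x, x_attn_mask, freqs_cis, adaln_input)
    return x
```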
@@ -22,6 +22,10 @@ def modulate(x, scale):
 # Core NextDiT Model #
 #############################################################################
 
+def clamp_fp16(x):
+    if x.dtype == torch.float16:
+        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+    return x
 
 class JointAttention(nn.Module):
     """Multi-head attention module."""
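float16 saturates at ±65504, so an attention or MLP output that overflows turns into `inf` and, after the residual add, poisons every later layer with NaNs. `clamp_fp16` maps non-finite values back into the representable range and is a no-op for every other dtype. Its behaviour in isolation:

```python
import torch

def clamp_fp16(x):
    if x.dtype == torch.float16:
        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
    return x

x = torch.tensor([1.0, float("inf"), float("-inf"), float("nan")], dtype=torch.float16)
print(clamp_fp16(x))          # tensor([1.0000e+00, 6.5504e+04, -6.5504e+04, 0.0000e+00], dtype=torch.float16)
print(clamp_fp16(x.float()))  # non-fp16 input passes through unchanged
```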
@@ -169,7 +173,7 @@ class FeedForward(nn.Module):
 
     # @torch.compile
     def _forward_silu_gating(self, x1, x3):
-        return F.silu(x1) * x3
+        return clamp_fp16(F.silu(x1) * x3)
 
     def forward(self, x):
         return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))

@@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module):
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
 
             x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
-                self.feed_forward(
+                clamp_fp16(self.feed_forward(
                     modulate(self.ffn_norm1(x), scale_mlp),
-                )
+                ))
             )
         else:
             assert adaln_input is None
             x = x + self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     self.attention_norm1(x),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + self.ffn_norm2(
                 self.feed_forward(

@@ -564,7 +568,7 @@ class NextDiT(nn.Module):
         ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
 
     # def forward(self, x, t, cap_feats, cap_mask):
-    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
         t = 1.0 - timesteps
         cap_feats = context
         cap_mask = attention_mask

@@ -581,16 +585,23 @@ class NextDiT(nn.Module):
 
         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute
 
-        transformer_options = kwargs.get("transformer_options", {})
+        patches = transformer_options.get("patches", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
-        freqs_cis = freqs_cis.to(x.device)
+        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+        freqs_cis = freqs_cis.to(img.device)
 
-        for layer in self.layers:
-            x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+        for i, layer in enumerate(self.layers):
+            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+            if "double_block" in patches:
+                for p in patches["double_block"]:
+                    out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    if "img" in out:
+                        img[:, cap_size[0]:] = out["img"]
+                    if "txt" in out:
+                        img[:, :cap_size[0]] = out["txt"]
 
-        x = self.final_layer(x, adaln_input)
-        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
+        img = self.final_layer(img, adaln_input)
+        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
 
-        return -x
+        return -img
@@ -517,6 +517,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 
 @wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"

@@ -541,6 +542,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+        exception_fallback = True
+    if exception_fallback:
         if tensor_layout == "NHD":
             q, k, v = map(
                 lambda t: t.transpose(1, 2),
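Both fallbacks in this commit share one refactor: the retry no longer runs inside the `except` block. While a handler is executing, the caught exception and its traceback stay alive (and a new failure would be chained onto the old one); after an OOM, those references can pin exactly the memory the fallback needs. Setting a flag and retrying after the handler releases them first. The shape of the pattern, with stand-in callables:

```python
def with_fallback(fast_path, slow_path, *args):
    # Run the fallback *outside* the except block so the caught exception and
    # its traceback (which keep frame/tensor references alive) are released
    # before the retry allocates memory.
    use_fallback = False
    try:
        return fast_path(*args)
    except Exception as e:
        print(f"fast path failed ({e}); falling back")
        use_fallback = True
    if use_fallback:
        return slow_path(*args)
```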
@@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
 
+
+def torch_cat_if_needed(xl, dim):
+    if len(xl) > 1:
+        return torch.cat(xl, dim)
+    else:
+        return xl[0]
+
 def get_timestep_embedding(timesteps, embedding_dim):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:

@@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
     return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
 
 
+class CarriedConv3d(nn.Module):
+    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
+        super().__init__()
+        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+    x = xl[0]
+    xl.clear()
+
+    if isinstance(op, CarriedConv3d):
+        if conv_carry_in is None:
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate')
+        else:
+            carry_len = conv_carry_in[0].shape[2]
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode='replicate')
+            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+
+        if conv_carry_out is not None:
+            to_push = x[:, :, -2:, :, :].clone()
+            conv_carry_out.append(to_push)
+
+    out = op(x)
+
+    return out
+
+
 class VideoConv3d(nn.Module):
     def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
         super().__init__()
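`CarriedConv3d` plus `conv_carry_causal_3d` let a causally padded temporal convolution run over a video in chunks: the first chunk is replicate-padded with 2 frames at the front, every chunk stashes its last 2 (padded) frames into `conv_carry_out`, and the next chunk prepends that carry instead of padding. For a kernel with temporal extent 3 this is exactly equivalent to convolving the whole clip at once, which the sketch below checks (plain `Conv3d`, time axis only, spatial padding omitted):

```python
import torch

torch.manual_seed(0)
conv = torch.nn.Conv3d(1, 1, kernel_size=(3, 1, 1))
x = torch.randn(1, 1, 8, 4, 4)

# Reference: causal-pad the full clip by 2 frames, convolve once.
full = conv(torch.nn.functional.pad(x, (0, 0, 0, 0, 2, 0), mode="replicate"))

# Chunked: pad only the first chunk, carry its last 2 frames into the second.
a, b = x[:, :, :4], x[:, :, 4:]
a_pad = torch.nn.functional.pad(a, (0, 0, 0, 0, 2, 0), mode="replicate")
carry = a_pad[:, :, -2:]
chunked = torch.cat([conv(a_pad), conv(torch.cat([carry, b], dim=2))], dim=2)

assert torch.allclose(full, chunked, atol=1e-6)
```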
@@ -89,29 +126,24 @@ class Upsample(nn.Module):
                                     stride=1,
                                     padding=1)
 
-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         scale_factor = self.scale_factor
         if isinstance(scale_factor, (int, float)):
             scale_factor = (scale_factor,) * (x.ndim - 2)
 
         if x.ndim == 5 and scale_factor[0] > 1.0:
-            t = x.shape[2]
-            if t > 1:
-                a, b = x.split((1, t - 1), dim=2)
-                del x
-                b = interpolate_up(b, scale_factor)
-            else:
-                a = x
-
-            a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
-            if t > 1:
-                x = torch.cat((a, b), dim=2)
-            else:
-                x = a
+            results = []
+            if conv_carry_in is None:
+                first = x[:, :, :1, :, :]
+                results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
+                x = x[:, :, 1:, :, :]
+            if x.shape[2] > 0:
+                results.append(interpolate_up(x, scale_factor))
+            x = torch_cat_if_needed(results, dim=2)
         else:
             x = interpolate_up(x, scale_factor)
         if self.with_conv:
-            x = self.conv(x)
+            x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
         return x

@@ -127,12 +159,15 @@ class Downsample(nn.Module):
                                     stride=stride,
                                     padding=0)
 
-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         if self.with_conv:
-            if x.ndim == 4:
+            if isinstance(self.conv, CarriedConv3d):
+                x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+            elif x.ndim == 4:
                 pad = (0, 1, 0, 1)
                 mode = "constant"
                 x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
+                x = self.conv(x)
             elif x.ndim == 5:
                 pad = (1, 1, 1, 1, 2, 0)
                 mode = "replicate"

@@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
                                         stride=1,
                                         padding=0)
 
-    def forward(self, x, temb=None):
+    def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
         h = x
         h = self.norm1(h)
-        h = self.swish(h)
-        h = self.conv1(h)
+        h = [ self.swish(h) ]
+        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
 
         if temb is not None:
             h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
 
         h = self.norm2(h)
         h = self.swish(h)
-        h = self.dropout(h)
-        h = self.conv2(h)
+        h = [ self.dropout(h) ]
+        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
 
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
+                x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
             else:
                 x = self.nin_shortcut(x)

@@ -279,6 +314,7 @@ def pytorch_attention(q, k, v):
     orig_shape = q.shape
     B = orig_shape[0]
     C = orig_shape[1]
+    oom_fallback = False
    q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
         (q, k, v),

@@ -289,6 +325,8 @@ def pytorch_attention(q, k, v):
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+        oom_fallback = True
+    if oom_fallback:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
     return out

@@ -517,8 +555,13 @@ class Encoder(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
         self.in_channels = in_channels
+        self.carried = False
 
         if conv3d:
-            conv_op = VideoConv3d
+            if not attn_resolutions:
+                conv_op = CarriedConv3d
+                self.carried = True
+            else:
+                conv_op = VideoConv3d
             mid_attn_conv_op = ops.Conv3d
         else:

@@ -532,6 +575,7 @@ class Encoder(nn.Module):
                                        stride=1,
                                        padding=1)
 
+        self.time_compress = 1
         curr_res = resolution
         in_ch_mult = (1,)+tuple(ch_mult)
         self.in_ch_mult = in_ch_mult

@@ -558,10 +602,15 @@ class Encoder(nn.Module):
                 if time_compress is not None:
                     if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
                         stride = (1, 2, 2)
+                    else:
+                        self.time_compress *= 2
                 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
                 curr_res = curr_res // 2
             self.down.append(down)
 
+        if time_compress is not None:
+            self.time_compress = time_compress
+
         # middle
         self.mid = nn.Module()
         self.mid.block_1 = ResnetBlock(in_channels=block_in,

@@ -587,15 +636,42 @@ class Encoder(nn.Module):
     def forward(self, x):
         # timestep embedding
         temb = None
 
+        if self.carried:
+            xl = [x[:, :, :1, :, :]]
+            if x.shape[2] > self.time_compress:
+                tc = self.time_compress
+                xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim=2)
+            x = xl
+        else:
+            x = [x]
+        out = []
+
+        conv_carry_in = None
+
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+
-        # downsampling
-        h = self.conv_in(x)
+            # downsampling
+            x1 = [ x1 ]
+            h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
 
-        for i_level in range(self.num_resolutions):
-            for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h, temb)
-                if len(self.down[i_level].attn) > 0:
-                    h = self.down[i_level].attn[i_block](h)
-            if i_level != self.num_resolutions-1:
-                h = self.down[i_level].downsample(h)
+            for i_level in range(self.num_resolutions):
+                for i_block in range(self.num_res_blocks):
+                    h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
+                    if len(self.down[i_level].attn) > 0:
+                        assert i == 0  # carried should not happen if attn exists
+                        h1 = self.down[i_level].attn[i_block](h1)
+                if i_level != self.num_resolutions-1:
+                    h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
+
+            out.append(h1)
+            conv_carry_in = conv_carry_out
+
+        h = torch_cat_if_needed(out, dim=2)
+        del out
 
         # middle
         h = self.mid.block_1(h, temb)
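The carried `Encoder.forward` splits the video into a 1-frame head plus chunks of `time_compress * 2` frames and runs the whole down-sampling stack once per chunk, threading `conv_carry_in`/`conv_carry_out` between iterations. The carry list behaves as a FIFO: every `CarriedConv3d` pops the state its predecessor chunk pushed, in the same order the convolutions execute, and the last chunk stores nothing. The control flow, reduced to a toy:

```python
def run_carried(chunks, stages):
    # Thread a FIFO carry list from chunk to chunk; the last chunk skips saving.
    out, carry_in = [], None
    for i, h in enumerate(chunks):
        carry_out = [] if i < len(chunks) - 1 else None
        for stage in stages:
            h = stage(h, carry_in, carry_out)
        out.append(h)
        carry_in = carry_out
    return out

def stage(h, carry_in, carry_out):
    # Toy stage: "carry" the previous chunk's last element instead of conv state.
    prev = carry_in.pop(0) if carry_in else 0
    if carry_out is not None:
        carry_out.append(h[-1])
    return [prev + v for v in h]

print(run_carried([[1, 2], [3, 4]], [stage]))  # [[1, 2], [5, 6]]
```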
@@ -604,15 +680,15 @@ class Encoder(nn.Module):
 
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
+        h = [ nonlinearity(h) ]
+        h = conv_carry_causal_3d(h, self.conv_out)
         return h
 
 
 class Decoder(nn.Module):
     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
-                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+                 resolution, z_channels, tanh_out=False, use_linear_attn=False,
                  conv_out_op=ops.Conv2d,
                  resnet_op=ResnetBlock,
                  attn_op=AttnBlock,

@@ -626,12 +702,18 @@ class Decoder(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
         self.in_channels = in_channels
-        self.give_pre_end = give_pre_end
         self.tanh_out = tanh_out
+        self.carried = False
 
         if conv3d:
-            conv_op = VideoConv3d
-            conv_out_op = VideoConv3d
+            if not attn_resolutions and resnet_op == ResnetBlock:
+                conv_op = CarriedConv3d
+                conv_out_op = CarriedConv3d
+                self.carried = True
+            else:
+                conv_op = VideoConv3d
+                conv_out_op = VideoConv3d
+
             mid_attn_conv_op = ops.Conv3d
         else:
             conv_op = ops.Conv2d

@@ -706,29 +788,43 @@ class Decoder(nn.Module):
         temb = None
 
         # z to block_in
-        h = self.conv_in(z)
+        h = conv_carry_causal_3d([z], self.conv_in)
 
         # middle
         h = self.mid.block_1(h, temb, **kwargs)
         h = self.mid.attn_1(h, **kwargs)
         h = self.mid.block_2(h, temb, **kwargs)
 
+        if self.carried:
+            h = torch.split(h, 2, dim=2)
+        else:
+            h = [ h ]
+        out = []
+
+        conv_carry_in = None
+
         # upsampling
-        for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks+1):
-                h = self.up[i_level].block[i_block](h, temb, **kwargs)
-                if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h, **kwargs)
-            if i_level != 0:
-                h = self.up[i_level].upsample(h)
-
-        # end
-        if self.give_pre_end:
-            return h
-
-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h, **kwargs)
-        if self.tanh_out:
-            h = torch.tanh(h)
-        return h
+        for i, h1 in enumerate(h):
+            conv_carry_out = []
+            if i == len(h) - 1:
+                conv_carry_out = None
+            for i_level in reversed(range(self.num_resolutions)):
+                for i_block in range(self.num_res_blocks+1):
+                    h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
+                    if len(self.up[i_level].attn) > 0:
+                        assert i == 0  # carried should not happen if attn exists
+                        h1 = self.up[i_level].attn[i_block](h1, **kwargs)
+                if i_level != 0:
+                    h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
+
+            h1 = self.norm_out(h1)
+            h1 = [ nonlinearity(h1) ]
+            h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
+
+            if self.tanh_out:
+                h1 = torch.tanh(h1)
+            out.append(h1)
+            conv_carry_in = conv_carry_out
+
+        out = torch_cat_if_needed(out, dim=2)
+
+        return out
@@ -322,6 +322,13 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["diffusion_model.{}".format(key_lora)] = to
             key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
 
+    if isinstance(model, comfy.model_base.Kandinsky5):
+        for k in sdk:
+            if k.startswith("diffusion_model.") and k.endswith(".weight"):
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                key_map["{}".format(key_lora)] = k
+                key_map["transformer.{}".format(key_lora)] = k
+
     return key_map
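`model_lora_keys_unet` builds a dictionary from the names a LoRA file may use to the model's own state-dict keys. The new Kandinsky5 branch registers every weight twice, once bare and once with a `transformer.` prefix, so LoRAs exported under either convention resolve to the same parameter. What that produces for a single weight:

```python
# One state-dict key run through the new Kandinsky5 mapping logic.
sdk = ["diffusion_model.visual_transformer_blocks.0.self_attention.to_query.weight"]

key_map = {}
for k in sdk:
    if k.startswith("diffusion_model.") and k.endswith(".weight"):
        key_lora = k[len("diffusion_model."):-len(".weight")]
        key_map[key_lora] = k                            # bare convention
        key_map["transformer.{}".format(key_lora)] = k   # "transformer."-prefixed convention

for lora_key, model_key in key_map.items():
    print(lora_key, "->", model_key)
```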
@@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model
 import comfy.ldm.ace.model
 import comfy.ldm.omnigen.omnigen2
 import comfy.ldm.qwen_image.model
+import comfy.ldm.kandinsky5.model
 
 import comfy.model_management
 import comfy.patcher_extension

@@ -134,7 +135,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

@@ -329,18 +330,6 @@ class BaseModel(torch.nn.Module):
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
 
         unet_state_dict = self.diffusion_model.state_dict()
-
-        if self.model_config.scaled_fp8 is not None:
-            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
-
-        # Save mixed precision metadata
-        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
-            metadata = {
-                "format_version": "1.0",
-                "layers": self.model_config.layer_quant_config
-            }
-            unet_state_dict["_quantization_metadata"] = metadata
-
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
         if self.model_type == ModelType.V_PREDICTION:

@@ -1642,3 +1631,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
         out = super().extra_conds(**kwargs)
         out['disable_time_r'] = comfy.conds.CONDConstant(False)
         return out
+
+class Kandinsky5(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
+
+    def encode_adm(self, **kwargs):
+        return kwargs["pooled_output"]
+
+    def concat_cond(self, **kwargs):
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+        image = torch.zeros_like(noise)
+
+        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if mask is None:
+            mask = torch.zeros_like(noise)[:, :1]
+        else:
+            mask = 1.0 - mask
+            mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            if mask.shape[-3] < noise.shape[-3]:
+                mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+            mask = utils.resize_to_batch_size(mask, noise.shape[0])
+
+        return torch.cat((image, mask), dim=1)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        time_dim_replace = kwargs.get("time_dim_replace", None)
+        if time_dim_replace is not None:
+            out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
+
+        return out
+
+class Kandinsky5Image(Kandinsky5):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device)
+
+    def concat_cond(self, **kwargs):
+        return None
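`Kandinsky5.concat_cond` assembles the image-to-video conditioning channels: an all-zero image latent concatenated with a one-channel mask, where the denoise mask is inverted, rescaled to the latent resolution, zero-padded along time up to the noise length, and broadcast to the batch. A shape walk-through with plain torch ops standing in for comfy's `utils.common_upscale` / `utils.resize_to_batch_size` (the trilinear resize below is a stand-in, not the exact upscale used):

```python
import torch
import torch.nn.functional as F

noise = torch.randn(2, 16, 8, 32, 32)       # (B, C, T, H, W) latent noise
denoise_mask = torch.rand(1, 1, 5, 64, 64)  # shorter in T, finer in H/W

mask = 1.0 - denoise_mask                                       # invert the denoise mask
mask = F.interpolate(mask, size=(5, 32, 32), mode="trilinear")  # to latent resolution
mask = F.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]))  # zero-pad time
mask = mask.expand(noise.shape[0], -1, -1, -1, -1)              # to batch size

cond = torch.cat((torch.zeros_like(noise), mask), dim=1)
print(cond.shape)  # torch.Size([2, 17, 8, 32, 32])
```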
@@ -6,20 +6,6 @@ import math
 import logging
 import torch
 
-
-def detect_layer_quantization(metadata):
-    quant_key = "_quantization_metadata"
-    if metadata is not None and quant_key in metadata:
-        quant_metadata = metadata.pop(quant_key)
-        quant_metadata = json.loads(quant_metadata)
-        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
-            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
-            return quant_metadata["layers"]
-        else:
-            raise ValueError("Invalid quantization metadata format")
-    return None
-
 
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:

@@ -208,12 +194,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["theta"] = 2000
         dit_config["out_channels"] = 128
         dit_config["global_modulation"] = True
-        dit_config["vec_in_dim"] = None
         dit_config["mlp_silu_act"] = True
         dit_config["qkv_bias"] = False
         dit_config["ops_bias"] = False
         dit_config["default_ref_method"] = "index"
         dit_config["ref_index_scale"] = 10.0
+        dit_config["txt_ids_dims"] = [3]
         patch_size = 1
     else:
         dit_config["image_model"] = "flux"

@@ -223,6 +209,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["theta"] = 10000
         dit_config["out_channels"] = 16
         dit_config["qkv_bias"] = True
+        dit_config["txt_ids_dims"] = []
         patch_size = 2
 
     dit_config["in_channels"] = 16

@@ -245,6 +232,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
     if vec_in_key in state_dict_keys:
         dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+    else:
+        dit_config["vec_in_dim"] = None
 
     dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
     dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')

@@ -270,6 +259,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["nerf_embedder_dtype"] = torch.float32
     else:
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+        dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
+        dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+        if dit_config["yak_mlp"] and dit_config["txt_norm"]:  # Ovis model
+            dit_config["txt_ids_dims"] = [1, 2]
+
     return dit_config
 
     if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview

@@ -617,6 +611,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
         return dit_config
 
+    if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
+        dit_config = {}
+        model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+        dit_config["model_dim"] = model_dim
+        if model_dim in [4096, 2560]: # pro video and lite image
+            dit_config["axes_dims"] = (32, 48, 48)
+            if model_dim == 2560: # lite image
+                dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
+        elif model_dim == 1792: # lite video
+            dit_config["axes_dims"] = (16, 24, 24)
+        dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+        dit_config["image_model"] = "kandinsky5"
+        dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
+        dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
+        dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
+        dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None
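The new Kandinsky 5 branch illustrates how `detect_unet_config` works in general: a marker key proves the architecture, embedding shapes pin down `model_dim`, `ff_dim` and friends, and numbered key prefixes give the block counts. A simplified reimplementation of the counting idiom (the real `count_blocks` lives at the top of this file; this version is a sketch):

```python
def count_blocks(state_dict_keys, prefix_string):
    # Count consecutive numbered prefixes: "...blocks.0.", "...blocks.1.", ...
    count = 0
    while any(k.startswith(prefix_string.format(count)) for k in state_dict_keys):
        count += 1
    return count

state_dict_keys = [
    "visual_transformer_blocks.0.self_attention.to_query.weight",
    "visual_transformer_blocks.1.self_attention.to_query.weight",
    "visual_transformer_blocks.2.self_attention.to_query.weight",
]
print(count_blocks(state_dict_keys, "visual_transformer_blocks.{}."))  # 3
```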
@@ -759,22 +771,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     if model_config is None and use_base_if_no_match:
         model_config = comfy.supported_models_base.BASE(unet_config)
 
-    scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
-    if scaled_fp8_key in state_dict:
-        scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
-        model_config.scaled_fp8 = scaled_fp8_weight.dtype
-        if model_config.scaled_fp8 == torch.float32:
-            model_config.scaled_fp8 = torch.float8_e4m3fn
-        if scaled_fp8_weight.nelement() == 2:
-            model_config.optimizations["fp8"] = False
-        else:
-            model_config.optimizations["fp8"] = True
-
     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(metadata)
-    if layer_quant_config:
-        model_config.layer_quant_config = layer_quant_config
-        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+    quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
+    if quant_config:
+        model_config.quant_config = quant_config
+        logging.info("Detected mixed precision quantization")
 
     return model_config
@@ -1492,6 +1492,20 @@ def extended_fp16_support():
 
     return True
 
+LORA_COMPUTE_DTYPES = {}
+
+def lora_compute_dtype(device):
+    dtype = LORA_COMPUTE_DTYPES.get(device, None)
+    if dtype is not None:
+        return dtype
+
+    if should_use_fp16(device):
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
+    LORA_COMPUTE_DTYPES[device] = dtype
+    return dtype
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
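`lora_compute_dtype` decides once per device whether LoRA weight math may run in fp16 (via `should_use_fp16`) and memoizes the answer; the `ModelPatcher` changes below use it instead of always casting to fp32, halving the temporary buffers on capable GPUs. The memoization shape, with the capability probe as a stand-in assumption:

```python
import torch

_DTYPES = {}

def compute_dtype(device, fp16_ok=lambda d: str(d) != "cpu"):
    # Per-device memoized dtype choice; `fp16_ok` is a stand-in probe,
    # not comfy's should_use_fp16.
    dtype = _DTYPES.get(device)
    if dtype is None:
        dtype = torch.float16 if fp16_ok(device) else torch.float32
        _DTYPES[device] = dtype
    return dtype

print(compute_dtype(torch.device("cpu")))  # torch.float32, computed once
print(compute_dtype(torch.device("cpu")))  # torch.float32, served from the cache
```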
@@ -126,27 +126,11 @@ class LowVramPatch:
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
-        self.convert_func = convert_func
+        self.convert_func = convert_func # TODO: remove
         self.set_func = set_func
 
     def __call__(self, weight):
-        intermediate_dtype = weight.dtype
-        if self.convert_func is not None:
-            weight = self.convert_func(weight, inplace=False)
-
-        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
-            intermediate_dtype = torch.float32
-            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
-            if self.set_func is None:
-                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
-            else:
-                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
-
-        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
-        if self.set_func is not None:
-            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
-        else:
-            return out
+        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
 
 #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
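`LowVramPatch.__call__` now delegates everything, dtype clamping included, to `comfy.lora.calculate_weight`; the old body also applied stochastic rounding when writing back to low-precision dtypes. For readers unfamiliar with the idea, a toy illustration (not the `comfy.float` implementation):

```python
import torch

def stochastic_round_toy(x: torch.Tensor, step: float, seed: int) -> torch.Tensor:
    # Round onto a grid of spacing `step`, up or down with probability equal
    # to the distance to the opposite neighbour: zero error in expectation.
    g = torch.Generator(device=x.device).manual_seed(seed)
    scaled = x / step
    low = torch.floor(scaled)
    p_up = scaled - low
    up = (torch.rand(scaled.shape, generator=g, device=x.device) < p_up).to(x.dtype)
    return (low + up) * step

x = torch.full((100000,), 0.1)
assert abs(stochastic_round_toy(x, step=0.25, seed=0).mean().item() - 0.1) < 2e-3
```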
@@ -630,10 +614,11 @@ class ModelPatcher:
         if key not in self.backup:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
 
+        temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
         if device_to is not None:
-            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+            temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
         else:
-            temp_weight = weight.to(torch.float32, copy=True)
+            temp_weight = weight.to(temp_dtype, copy=True)
         if convert_func is not None:
             temp_weight = convert_func(temp_weight, inplace=True)
 
@@ -699,12 +684,12 @@ class ModelPatcher:
         offloaded = []
         offload_buffer = 0
         loading.sort(reverse=True)
-        for x in loading:
+        for i, x in enumerate(loading):
             module_offload_mem, module_mem, n, m, params = x
 
             lowvram_weight = False
 
-            potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+            potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
             lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
 
             weight_key = "{}.weight".format(n)
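The replaced estimate assumed every in-flight stream holds a module as large as the current one; the new one sums the actual sizes of the next `NUM_STREAMS` entries in the descending `loading` list. Toy numbers (made up) to show why the bound tightens:

```python
NUM_STREAMS = 2
loading = [(100, 90), (80, 70), (60, 50), (40, 30)]  # (offload_mem, module_mem), sorted descending

i = 1
module_offload_mem = loading[i][0]
old_estimate = module_offload_mem * (NUM_STREAMS + 1)                                      # 80 * 3 = 240
new_estimate = module_offload_mem + sum(m for _, m in loading[i + 1:i + 1 + NUM_STREAMS])  # 80 + 50 + 30 = 160
assert new_estimate < old_estimate  # holds here because later entries only shrink
```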
@@ -777,6 +762,8 @@ class ModelPatcher:
                 key = "{}.{}".format(n, param)
                 self.unpin_weight(key)
                 self.patch_weight_to_device(key, device_to=device_to)
+            if comfy.model_management.is_device_cuda(device_to):
+                torch.cuda.synchronize()
 
             logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
             m.comfy_patched_weights = True
@@ -876,14 +863,18 @@ class ModelPatcher:
         patch_counter = 0
         unload_list = self._load_list()
         unload_list.sort()
 
         offload_buffer = self.model.model_offload_buffer_memory
+        if len(unload_list) > 0:
+            NS = comfy.model_management.NUM_STREAMS
+            offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
+
         for unload in unload_list:
             if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
                 break
             module_offload_mem, module_mem, n, m, params = unload
 
-            potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+            potential_offload = module_offload_mem + sum(offload_weight_factor)
 
             lowvram_possible = hasattr(m, "comfy_cast_weights")
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -935,6 +926,8 @@ class ModelPatcher:
                     m.comfy_patched_weights = False
                     memory_freed += module_mem
                     offload_buffer = max(offload_buffer, potential_offload)
+                    offload_weight_factor.append(module_mem)
+                    offload_weight_factor.pop(0)
                     logging.debug("freed {}".format(n))
 
             for param in params:
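On the unload side the same idea appears as `offload_weight_factor`, a fixed-length window of recently freed module sizes maintained with `append`/`pop(0)`. The pattern in miniature, using `collections.deque` for clarity:

```python
from collections import deque

NUM_STREAMS = 2
window = deque([64, 64], maxlen=NUM_STREAMS)  # seeded like offload_weight_factor
for freed_module_mem in [100, 30, 50]:
    potential_offload = freed_module_mem + sum(window)
    window.append(freed_module_mem)  # maxlen drops the oldest automatically
print(potential_offload)  # 50 + (100 + 30) = 180
```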
164  comfy/ops.py
@@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
 import contextlib
+import json
 
 def run_every_op():
     if torch.compiler.is_compiling():
@@ -111,22 +112,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if s.bias is not None:
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
 
-        if bias_has_function:
-            with wf_context:
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    bias_a = bias
+    weight_a = weight
+
+    if s.bias is not None:
         for f in s.bias_function:
             bias = f(bias)
 
     if weight_has_function or weight.dtype != dtype:
-        with wf_context:
         weight = weight.to(dtype=dtype)
         if isinstance(weight, QuantizedTensor):
             weight = weight.dequantize()
         for f in s.weight_function:
             weight = f(weight)
 
-    comfy.model_management.sync_stream(device, offload_stream)
     if offloadable:
-        return weight, bias, offload_stream
+        return weight, bias, (offload_stream, weight_a, bias_a)
     else:
         #Legacy function signature
         return weight, bias
@@ -135,13 +138,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
 def uncast_bias_weight(s, weight, bias, offload_stream):
     if offload_stream is None:
         return
-    if weight is not None:
-        device = weight.device
-    else:
-        if bias is None:
-            return
-        device = bias.device
-    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+    os, weight_a, bias_a = offload_stream
+    if os is None:
+        return
+    if weight_a is not None:
+        device = weight_a.device
+    else:
+        if bias_a is None:
+            return
+        device = bias_a.device
+    os.wait_stream(comfy.model_management.current_stream(device))
 
 
 class CastWeightBiasOp:
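The `offloadable` return value is now a tuple that keeps references to the tensors produced on the offload stream (`weight_a`, `bias_a`), so `uncast_bias_weight` can still find the right device and order the streams even after `weight_function` has replaced `weight`. The underlying CUDA pattern, reduced to essentials:

```python
import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    with torch.cuda.stream(copy_stream):
        w = x.to(torch.float16)  # async cast on the side stream
    torch.cuda.current_stream().wait_stream(copy_stream)  # like sync_stream(): compute waits for the cast
    y = w @ w
    copy_stream.wait_stream(torch.cuda.current_stream())  # like uncast: side stream waits before buffer reuse
```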
@@ -417,22 +423,12 @@ def fp8_linear(self, input):
 
     if input.ndim == 3 or input.ndim == 2:
         w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
 
-        scale_weight = self.scale_weight
-        scale_input = self.scale_input
-        if scale_weight is None:
-            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
-        else:
-            scale_weight = scale_weight.to(input.device)
-
-        if scale_input is None:
-            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
-        else:
-            scale_input = scale_input.to(input.device)
-            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
+        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+        input = torch.clamp(input, min=-448, max=448, out=input)
+        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
 
         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
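The clamp to ±448 exists because that is the largest finite magnitude float8_e4m3fn can represent; anything beyond it would overflow on the cast. A quick check (needs a PyTorch build with float8 dtypes, 2.1+):

```python
import torch

assert torch.finfo(torch.float8_e4m3fn).max == 448.0
x = torch.tensor([1000.0, -1000.0, 3.14159])
x8 = torch.clamp(x, min=-448, max=448).to(torch.float8_e4m3fn)
print(x8.to(torch.float32))  # ~[448., -448., 3.25]; 3.14159 lands on the nearest e4m3 grid point
```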
@@ -453,7 +449,7 @@ class fp8_ops(manual_cast):
         return None
 
     def forward_comfy_cast_weights(self, input):
-        if not self.training:
+        if len(self.weight_function) == 0 and len(self.bias_function) == 0:
             try:
                 out = fp8_linear(self, input)
                 if out is not None:
@@ -466,59 +462,6 @@ class fp8_ops(manual_cast):
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x
 
-def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
-    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
-    class scaled_fp8_op(manual_cast):
-        class Linear(manual_cast.Linear):
-            def __init__(self, *args, **kwargs):
-                if override_dtype is not None:
-                    kwargs['dtype'] = override_dtype
-                super().__init__(*args, **kwargs)
-
-            def reset_parameters(self):
-                if not hasattr(self, 'scale_weight'):
-                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
-                if not scale_input:
-                    self.scale_input = None
-
-                if not hasattr(self, 'scale_input'):
-                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-                return None
-
-            def forward_comfy_cast_weights(self, input):
-                if fp8_matrix_mult:
-                    out = fp8_linear(self, input)
-                    if out is not None:
-                        return out
-
-                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-
-                if weight.numel() < input.numel(): #TODO: optimize
-                    x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
-                else:
-                    x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
-                uncast_bias_weight(self, weight, bias, offload_stream)
-                return x
-
-            def convert_weight(self, weight, inplace=False, **kwargs):
-                if inplace:
-                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
-                    return weight
-                else:
-                    return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
-
-            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
-                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
-                if return_weight:
-                    return weight
-                if inplace_update:
-                    self.weight.data.copy_(weight)
-                else:
-                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-    return scaled_fp8_op
-
 CUBLAS_IS_AVAILABLE = False
 try:
     from cublas_ops import CublasLinear
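The removed `scaled_fp8_ops` forward applied the per-tensor scale to whichever operand was smaller, which is valid because a scalar scale commutes through the matmul:

```python
import torch

x = torch.randn(4, 8)
W = torch.randn(16, 8)
s = torch.tensor(0.037)
a = torch.nn.functional.linear(x, W * s)  # scale the weight...
b = torch.nn.functional.linear(x * s, W)  # ...or the input: same result
assert torch.allclose(a, b, atol=1e-5)
```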
@@ -545,9 +488,9 @@ if CUBLAS_IS_AVAILABLE:
 from .quant_ops import QuantizedTensor, QUANT_ALGOS
 
 
-def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
     class MixedPrecisionOps(manual_cast):
-        _layer_quant_config = layer_quant_config
+        _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm
 
@@ -590,27 +533,38 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
 
                 manually_loaded_keys = [weight_key]
 
-            if layer_name not in MixedPrecisionOps._layer_quant_config:
+            layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+            if layer_conf is not None:
+                layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+            if layer_conf is None:
                 self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
             else:
-                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
-                if quant_format is None:
+                self.quant_format = layer_conf.get("format", None)
+                if not self._full_precision_mm:
+                    self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+
+                if self.quant_format is None:
                     raise ValueError(f"Unknown quantization format for layer {layer_name}")
 
-                qconfig = QUANT_ALGOS[quant_format]
+                qconfig = QUANT_ALGOS[self.quant_format]
                 self.layout_type = qconfig["comfy_tensor_layout"]
 
                 weight_scale_key = f"{prefix}weight_scale"
+                scale = state_dict.pop(weight_scale_key, None)
+                if scale is not None:
+                    scale = scale.to(device)
                 layout_params = {
-                    'scale': state_dict.pop(weight_scale_key, None),
+                    'scale': scale,
                     'orig_dtype': MixedPrecisionOps._compute_dtype,
                     'block_size': qconfig.get("group_size", None),
                 }
-                if layout_params['scale'] is not None:
+                if scale is not None:
                     manually_loaded_keys.append(weight_scale_key)
 
                 self.weight = torch.nn.Parameter(
-                    QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+                    QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
                     requires_grad=False
                 )
 
@@ -619,7 +573,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
                 _v = state_dict.pop(param_key, None)
                 if _v is None:
                     continue
-                setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                 manually_loaded_keys.append(param_key)
 
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -628,6 +582,16 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
                 if key in missing_keys:
                     missing_keys.remove(key)
 
+        def state_dict(self, *args, destination=None, prefix="", **kwargs):
+            sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
+            if isinstance(self.weight, QuantizedTensor):
+                sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+                quant_conf = {"format": self.quant_format}
+                if self._full_precision_mm:
+                    quant_conf["full_precision_matrix_mult"] = True
+                sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+            return sd
+
         def _forward(self, input, weight, bias):
             return torch.nn.functional.linear(input, weight, bias)
 
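Round-tripping the `comfy_quant` marker is just JSON packed into a uint8 tensor, so it travels through safetensors like any other weight. A self-contained round trip (helper names are mine, not ComfyUI API):

```python
import json
import torch

def encode_quant_marker(conf: dict) -> torch.Tensor:
    # Same shape as the state_dict() code above; torch.frombuffer may warn
    # about the read-only buffer, which is harmless here.
    return torch.frombuffer(json.dumps(conf).encode("utf-8"), dtype=torch.uint8)

def decode_quant_marker(t: torch.Tensor) -> dict:
    # Inverse of the _load_from_state_dict path: json.loads(t.numpy().tobytes()).
    return json.loads(t.numpy().tobytes())

marker = encode_quant_marker({"format": "float8_e4m3fn", "full_precision_matrix_mult": True})
assert decode_quant_marker(marker) == {"format": "float8_e4m3fn", "full_precision_matrix_mult": True}
```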
@@ -643,9 +607,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
             if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(input, *args, **kwargs)
             if (getattr(self, 'layout_type', None) is not None and
-                getattr(self, 'input_scale', None) is not None and
                 not isinstance(input, QuantizedTensor)):
-                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
             return self._forward(input, self.weight, self.bias)
 
         def convert_weight(self, weight, inplace=False, **kwargs):
@@ -656,7 +619,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
 
         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
             if getattr(self, 'layout_type', None) is not None:
-                weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+                weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
             else:
                 weight = weight.to(self.weight.dtype)
             if return_weight:
@@ -665,17 +628,28 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
             assert inplace_update is False # TODO: eventually remove the inplace_update stuff
             self.weight = torch.nn.Parameter(weight, requires_grad=False)
 
+        def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
+            if recurse:
+                for module in self.children():
+                    module._apply(fn)
+
+            for key, param in self._parameters.items():
+                if param is None:
+                    continue
+                self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
+            for key, buf in self._buffers.items():
+                if buf is not None:
+                    self._buffers[key] = fn(buf)
+            return self
+
     return MixedPrecisionOps
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
 
-    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
-        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
-        return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
+    if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
+        logging.info("Using mixed precision operations")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
 
-    if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
-
     if (
         fp8_compute and
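The `_apply` override matters because the stock `nn.Module._apply` mutates parameter storage in place, which can strip tensor subclasses like `QuantizedTensor` and confuses torch.compile after a device move; rebuilding each `Parameter` through `fn` sidesteps that. The pattern isolated (a sketch, not the ComfyUI class):

```python
import torch

class RebuildingLinear(torch.nn.Linear):
    def _apply(self, fn, recurse=True):
        # Re-register fresh Parameters instead of mutating .data in place.
        if recurse:
            for module in self.children():
                module._apply(fn)
        for key, param in self._parameters.items():
            if param is not None:
                self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)
        return self

m = RebuildingLinear(4, 4)
m.to(torch.float64)  # .to() routes through _apply
assert m.weight.dtype == torch.float64
```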
@@ -238,6 +238,9 @@ class QuantizedTensor(torch.Tensor):
     def is_contiguous(self, *arg, **kwargs):
         return self._qdata.is_contiguous(*arg, **kwargs)
 
+    def storage(self):
+        return self._qdata.storage()
+
 # ==============================================================================
 # Generic Utilities (Layout-Agnostic Operations)
 # ==============================================================================
@@ -249,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn):
 
 
 def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
-    if target_dtype is not None and target_dtype != qt.dtype:
-        logging.warning(
-            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
-            f"but not supported for quantized tensors. Ignoring dtype."
-        )
-
     if target_layout is not None and target_layout != torch.strided:
         logging.warning(
             f"QuantizedTensor: layout change requested to {target_layout}, "
@@ -274,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
     logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
     new_q_data = qt._qdata.to(device=target_device)
     new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+    if target_dtype is not None:
+        new_params["orig_dtype"] = target_dtype
     new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
     logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
     return new_qt
@@ -339,7 +338,9 @@ def generic_copy_(func, args, kwargs):
         # Copy from another quantized tensor
         qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
         qt_dest._layout_type = src._layout_type
+        orig_dtype = qt_dest._layout_params["orig_dtype"]
         _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
+        qt_dest._layout_params["orig_dtype"] = orig_dtype
     else:
         # Copy from regular tensor - just copy raw data
         qt_dest._qdata.copy_(src)
@@ -397,9 +398,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
     def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype
 
-        if scale is None:
+        if isinstance(scale, str) and scale == "recalculate":
             scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
 
+        if scale is not None:
             if not isinstance(scale, torch.Tensor):
                 scale = torch.tensor(scale)
             scale = scale.to(device=tensor.device, dtype=torch.float32)
@@ -408,6 +410,8 @@ class TensorCoreFP8Layout(QuantizedLayout):
                 tensor *= (1.0 / scale).to(tensor.dtype)
             else:
                 tensor = tensor * (1.0 / scale).to(tensor.dtype)
+        else:
+            scale = torch.ones((), device=tensor.device, dtype=torch.float32)
 
         if stochastic_rounding > 0:
             tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
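`scale="recalculate"` derives a fresh per-tensor scale from the absolute maximum, so the largest value lands exactly on the fp8 grid's largest magnitude; passing `None` now means "store unscaled" (scale fixed to 1). The absmax arithmetic in isolation (float8-capable PyTorch assumed):

```python
import torch

t = torch.tensor([0.5, -12.0, 3.0])
scale = torch.amax(t.abs()) / torch.finfo(torch.float8_e4m3fn).max  # 12 / 448
q = (t * (1.0 / scale)).to(torch.float8_e4m3fn)  # store q together with scale
deq = q.to(torch.float32) * scale
assert torch.allclose(deq, t, rtol=0.07)         # fp8 keeps roughly 2-3 significant bits
```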
144  comfy/sd.py
@@ -53,6 +53,8 @@ import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
 import comfy.text_encoders.z_image
+import comfy.text_encoders.ovis
+import comfy.text_encoders.kandinsky5
 
 import comfy.model_patcher
 import comfy.lora
@@ -97,7 +99,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
 
 
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
         if no_init:
             return
         params = target.params.copy()
@@ -128,6 +130,27 @@ class CLIP:
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
         self.patcher.is_clip = True
         self.apply_hooks_to_conds = None
+        if len(state_dict) > 0:
+            if isinstance(state_dict, list):
+                for c in state_dict:
+                    m, u = self.load_sd(c)
+                    if len(m) > 0:
+                        logging.warning("clip missing: {}".format(m))
+
+                    if len(u) > 0:
+                        logging.debug("clip unexpected: {}".format(u))
+            else:
+                m, u = self.load_sd(state_dict, full_model=True)
+                if len(m) > 0:
+                    m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
+                    if len(m_filter) > 0:
+                        logging.warning("clip missing: {}".format(m))
+                    else:
+                        logging.debug("clip missing: {}".format(m))
+
+                if len(u) > 0:
+                    logging.debug("clip unexpected {}:".format(u))
+
         if params['device'] == load_device:
             model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
@@ -192,6 +215,7 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model()
+        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         all_hooks.reset()
         self.patcher.patch_hooks(None)
         if show_pbar:
@@ -239,6 +263,7 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model()
+        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         o = self.cond_stage_model.encode_token_weights(tokens)
         cond, pooled = o[:2]
         if return_dict:
@@ -468,7 +493,7 @@ class VAE:
                                                 decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
 
             self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
         elif "decoder.conv_in.conv.weight" in sd:
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
             ddconfig["conv3d"] = True
@@ -480,8 +505,10 @@ class VAE:
             self.latent_dim = 3
             self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
             self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
-            self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
-            self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+            #This is likely to significantly over-estimate with single image or low frame counts as the
+            #implementation is able to completely skip caching. Rework if used as an image only VAE
+            self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
         elif "decoder.unpatcher3d.wavelets" in sd:
             self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
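For scale, the new decode estimate for a `[B, C, T, H, W]` latent budgets memory by a capped temporal factor rather than the raw frame count; rough numbers under made-up dimensions:

```python
T, H, W = 9, 64, 64                        # made-up latent dims
frames = min(8, (T - 1) * 4 + 1)           # temporal upsampling factor, capped at 8
est = 2800 * frames * H * W * (8 * 8) * 2  # fp16: dtype_size = 2 bytes
print(est / 2**30)                         # ~10.9 GiB requested for the decode
```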
@@ -740,6 +767,8 @@ class VAE:
         self.throw_exception_if_invalid()
         pixel_samples = None
         do_tile = False
+        if self.latent_dim == 2 and samples_in.ndim == 5:
+            samples_in = samples_in[:, :, 0]
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -956,16 +985,17 @@ class CLIPType(Enum):
     QWEN_IMAGE = 18
     HUNYUAN_IMAGE = 19
     HUNYUAN_VIDEO_15 = 20
+    OVIS = 21
+    KANDINSKY5 = 22
+    KANDINSKY5_IMAGE = 23
 
 
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
     for p in ckpt_paths:
         sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
-        if metadata is not None:
-            quant_metadata = metadata.get("_quantization_metadata", None)
-            if quant_metadata is not None:
-                sd["_quantization_metadata"] = quant_metadata
+        if model_options.get("custom_operations", None) is None:
+            sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
         clip_data.append(sd)
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
 
@@ -987,6 +1017,7 @@ class TEModel(Enum):
     MISTRAL3_24B = 14
     MISTRAL3_24B_PRUNED_FLUX2 = 15
     QWEN3_4B = 16
+    QWEN3_2B = 17
 
 
 def detect_te_model(sd):
@@ -1020,9 +1051,12 @@ def detect_te_model(sd):
         if weight.shape[0] == 512:
             return TEModel.QWEN25_7B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
-        if 'model.layers.0.self_attn.q_norm.weight' in sd:
-            return TEModel.QWEN3_4B
         weight = sd['model.layers.0.post_attention_layernorm.weight']
+        if 'model.layers.0.self_attn.q_norm.weight' in sd:
+            if weight.shape[0] == 2560:
+                return TEModel.QWEN3_4B
+            elif weight.shape[0] == 2048:
+                return TEModel.QWEN3_2B
         if weight.shape[0] == 5120:
             if "model.layers.39.post_attention_layernorm.weight" in sd:
                 return TEModel.MISTRAL3_24B
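Qwen3-4B and Qwen3-2B both carry the `q_norm` marker key, so the detector now falls through to the layernorm width (2560 vs 2048) to tell them apart. The same check in miniature:

```python
import torch

sd = {"model.layers.0.post_attention_layernorm.weight": torch.zeros(2048),
      "model.layers.0.self_attn.q_norm.weight": torch.zeros(128)}
width = sd["model.layers.0.post_attention_layernorm.weight"].shape[0]
assert {2560: "QWEN3_4B", 2048: "QWEN3_2B"}[width] == "QWEN3_2B"
```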
@@ -1078,7 +1112,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
             clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
         elif clip_type == CLIPType.HIDREAM:
-            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLRefinerClipModel
@@ -1102,7 +1136,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif clip_type == CLIPType.HIDREAM:
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
-                                                                        clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+                                                                        clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else: #CLIPType.MOCHI
             clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
@@ -1131,7 +1165,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif te_model == TEModel.LLAMA3_8:
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
-                                                                        clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+                                                                        clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         elif te_model == TEModel.QWEN25_3B:
             clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
@@ -1150,13 +1184,16 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif te_model == TEModel.QWEN3_4B:
             clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
+        elif te_model == TEModel.QWEN3_2B:
+            clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
         else:
             # clip_l
             if clip_type == CLIPType.SD3:
                 clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
                 clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
             elif clip_type == CLIPType.HIDREAM:
-                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
                 clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
             else:
                 clip_target.clip = sd1_clip.SD1ClipModel
@@ -1199,6 +1236,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
             clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
+        elif clip_type == CLIPType.KANDINSKY5:
+            clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
+        elif clip_type == CLIPType.KANDINSKY5_IMAGE:
+            clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
             clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -1211,19 +1254,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
 
     parameters = 0
     for c in clip_data:
-        if "_quantization_metadata" in c:
-            c.pop("_quantization_metadata")
        parameters += comfy.utils.calculate_parameters(c)
         tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
 
-    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
-    for c in clip_data:
-        m, u = clip.load_sd(c)
-        if len(m) > 0:
-            logging.warning("clip missing: {}".format(m))
-
-        if len(u) > 0:
-            logging.debug("clip unexpected: {}".format(u))
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
     return clip
 
 def load_gligen(ckpt_path):
@@ -1282,6 +1316,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()
 
+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
+
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
     if model_config is None:
         logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
@@ -1290,17 +1328,21 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
             return None
         return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
 
 
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if model_config.scaled_fp8 is not None:
+    if model_config.quant_config is not None:
         weight_dtype = None
 
-    model_config.custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is not None:
+        model_config.custom_operations = custom_operations
 
     unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
 
     if unet_dtype is None:
         unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
 
+    if model_config.quant_config is not None:
+        manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+    else:
+        manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
 
@@ -1319,22 +1361,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
             vae = VAE(sd=vae_sd, metadata=metadata)
 
     if output_clip:
+        if te_model_options.get("custom_operations", None) is None:
+            scaled_fp8_list = []
+            for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
+                if k.endswith(".scaled_fp8"):
+                    scaled_fp8_list.append(k[:-len("scaled_fp8")])
+
+            if len(scaled_fp8_list) > 0:
+                out_sd = {}
+                for k in sd:
+                    skip = False
+                    for pref in scaled_fp8_list:
+                        skip = skip or k.startswith(pref)
+                    if not skip:
+                        out_sd[k] = sd[k]
+
+                for pref in scaled_fp8_list:
+                    quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+                    for k in quant_sd:
+                        out_sd[k] = quant_sd[k]
+                sd = out_sd
+
         clip_target = model_config.clip_target(state_dict=sd)
         if clip_target is not None:
             clip_sd = model_config.process_clip_state_dict(sd)
             if len(clip_sd) > 0:
                 parameters = comfy.utils.calculate_parameters(clip_sd)
-                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
-                m, u = clip.load_sd(clip_sd, full_model=True)
-                if len(m) > 0:
-                    m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
-                    if len(m_filter) > 0:
-                        logging.warning("clip missing: {}".format(m))
-                    else:
-                        logging.debug("clip missing: {}".format(m))
-
-                if len(u) > 0:
-                    logging.debug("clip unexpected {}:".format(u))
+                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
             else:
                 logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
 
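The block above splits the checkpoint by prefix: any sub-model that shipped a legacy `<prefix>scaled_fp8` marker gets rebuilt through `convert_old_quants`, everything else passes through untouched. The key filtering in isolation:

```python
sd = {"te.scaled_fp8": 0, "te.w": 1, "vae.w": 2}  # toy keys
prefixes = [k[:-len("scaled_fp8")] for k in sd if k.endswith(".scaled_fp8")]
kept = {k: v for k, v in sd.items() if not any(k.startswith(p) for p in prefixes)}
assert prefixes == ["te."] and kept == {"vae.w": 2}
```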
@@ -1381,6 +1434,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     if len(temp_sd) > 0:
         sd = temp_sd
 
+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
     parameters = comfy.utils.calculate_parameters(sd)
     weight_dtype = comfy.utils.weight_dtype(sd)
 
@@ -1411,7 +1467,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
 
     offload_device = model_management.unet_offload_device()
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if model_config.scaled_fp8 is not None:
+    if model_config.quant_config is not None:
         weight_dtype = None
 
     if dtype is None:
@@ -1419,12 +1475,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     else:
         unet_dtype = dtype
 
-    if model_config.layer_quant_config is not None:
+    if model_config.quant_config is not None:
         manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
     else:
         manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
-    model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
+    if custom_operations is not None:
+        model_config.custom_operations = custom_operations
+
     if model_options.get("fp8_optimizations", False):
         model_config.optimizations["fp8"] = True
 
@@ -1463,6 +1522,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
     if vae is not None:
         vae_sd = vae.get_sd()
 
+    if metadata is None:
+        metadata = {}
+
     model_management.load_models_gpu(load_models, force_patch_weights=True)
     clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
     sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
@@ -107,29 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 config[k] = v

         operations = model_options.get("custom_operations", None)
-        scaled_fp8 = None
-        quantization_metadata = model_options.get("quantization_metadata", None)
+        quant_config = model_options.get("quantization_metadata", None)

         if operations is None:
-            layer_quant_config = None
-            if quantization_metadata is not None:
-                layer_quant_config = json.loads(quantization_metadata).get("layers", None)
-
-            if layer_quant_config is not None:
-                operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
-                logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
-            else:
-                # Fallback to scaled_fp8_ops for backward compatibility
-                scaled_fp8 = model_options.get("scaled_fp8", None)
-                if scaled_fp8 is not None:
-                    operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+            if quant_config is not None:
+                operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
+                logging.info("Using MixedPrecisionOps for text encoder")
             else:
                 operations = comfy.ops.manual_cast

         self.operations = operations
         self.transformer = model_class(config, dtype, device, self.operations)
-        if scaled_fp8 is not None:
-            self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))

         self.num_layers = self.transformer.num_layers

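For orientation, the operation-class selection above reduces to the following sketch (a restatement of the hunk, with the `comfy.ops` calls taken from the diff itself):

```python
# Sketch of the text-encoder operations selection after this change.
import comfy.ops

def pick_operations(model_options, dtype):
    operations = model_options.get("custom_operations", None)
    quant_config = model_options.get("quantization_metadata", None)
    if operations is None:
        if quant_config is not None:
            # quantized checkpoints: mixed-precision linear/conv ops
            operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
        else:
            # plain weights: cast to the compute dtype on the fly
            operations = comfy.ops.manual_cast
    return operations
```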
@@ -147,6 +135,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
         self.return_attention_masks = return_attention_masks
+        self.execution_device = None

         if layer == "hidden":
             assert layer_idx is not None
@@ -163,6 +152,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+        self.execution_device = options.get("execution_device", self.execution_device)
         if isinstance(self.layer, list) or self.layer == "all":
             pass
         elif layer_idx is None or abs(layer_idx) > self.num_layers:
@@ -175,6 +165,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer = self.options_default[0]
         self.layer_idx = self.options_default[1]
         self.return_projected_pooled = self.options_default[2]
+        self.execution_device = None

     def process_tokens(self, tokens, device):
         end_token = self.special_tokens.get("end", None)
@@ -258,7 +249,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info

     def forward(self, tokens):
-        device = self.transformer.get_input_embeddings().weight.device
+        if self.execution_device is None:
+            device = self.transformer.get_input_embeddings().weight.device
+        else:
+            device = self.execution_device

         embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)

         attention_mask_model = None
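The new `execution_device` option can be driven through the clip-options API shown above; a minimal sketch (the `clip_model` object and device string are assumed):

```python
# Minimal sketch, assuming clip_model is an SDClipModel instance.
clip_model.set_clip_options({"execution_device": "cuda:0"})
embeds = clip_model(tokens)          # token processing now happens on cuda:0
clip_model.reset_clip_options()      # execution_device falls back to None
```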
@@ -21,6 +21,7 @@ import comfy.text_encoders.ace
 import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.kandinsky5
 import comfy.text_encoders.z_image

 from . import supported_models_base
@@ -1027,6 +1028,8 @@ class ZImage(Lumina2):

     memory_usage_factor = 1.7

+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
@@ -1472,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]

+class Kandinsky5(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "kandinsky5",
+    }
+
+    sampling_settings = {
+        "shift": 10.0,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.HunyuanVideo
+
+    memory_usage_factor = 1.1 #TODO
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Kandinsky5(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
+
+
+class Kandinsky5Image(Kandinsky5):
+    unet_config = {
+        "image_model": "kandinsky5",
+        "model_dim": 2560,
+        "visual_embed_dim": 64,
+    }
+
+    sampling_settings = {
+        "shift": 3.0,
+    }
+
+    latent_format = latent_formats.Flux
+    memory_usage_factor = 1.1 #TODO
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Kandinsky5Image(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]

 models += [SVD_img2vid]
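Note the ordering of the two new entries in `models`: detection is assumed to walk the list and take the first config whose `unet_config` keys all match the detected state dict, so the more specific `Kandinsky5Image` must precede `Kandinsky5`:

```python
# Sketch of the first-match detection order (matching rule assumed).
detected = {"image_model": "kandinsky5", "model_dim": 2560, "visual_embed_dim": 64}

for model_config in (Kandinsky5Image, Kandinsky5):
    if all(detected.get(k) == v for k, v in model_config.unet_config.items()):
        print("matched", model_config.__name__)  # -> Kandinsky5Image
        break
```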
@@ -17,6 +17,7 @@
 """

 import torch
+import logging
 from . import model_base
 from . import utils
 from . import latent_formats
@@ -49,8 +50,7 @@ class BASE:

     manual_cast_dtype = None
     custom_operations = None
-    scaled_fp8 = None
-    layer_quant_config = None # Per-layer quantization configuration for mixed precision
+    quant_config = None # quantization configuration for mixed precision
     optimizations = {"fp8": False}

     @classmethod
@@ -118,3 +118,7 @@ class BASE:
     def set_inference_dtype(self, dtype, manual_cast_dtype):
         self.unet_config['dtype'] = dtype
         self.manual_cast_dtype = manual_cast_dtype
+
+    def __getattr__(self, name):
+        logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
+        return None
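The new `__getattr__` fallback changes what downstream code sees when it reads a removed attribute such as `scaled_fp8`; a sketch of the behavior:

```python
# Sketch: old code reading the removed attribute no longer raises.
value = model_config.scaled_fp8  # logs "WARNING, you accessed scaled_fp8 ..."
assert value is None             # callers should migrate to model_config.quant_config
```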
@@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
-        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
-        if t5xxl_scaled_fp8 is not None:
+        t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+        if t5xxl_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = t5xxl_scaled_fp8
+            model_options["quantization_metadata"] = t5xxl_quantization_metadata

         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)

@@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
     class CosmosTEModel_(CosmosT5XXL):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             if dtype is None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
         else:
             return self.t5xxl.load_sd(sd)

-def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
+def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
     class FluxClipModel_(FluxClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
     return FluxClipModel_

@@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
         out = out.reshape(out.shape[0], out.shape[1], -1)
         return out, pooled, extra

-def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
+def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
     class Flux2TEModel_(Flux2TEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
-                model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
             if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
                 model_options["quantization_metadata"] = llama_quantization_metadata
             if pruned:
                 model_options = model_options.copy()
@@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


-def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
     class MochiTEModel_(MochiT5XXL):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             if dtype is None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
         return self.llama.load_sd(sd)


-def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
+def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
     class HiDreamTEModel_(HiDreamTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
-            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["llama_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
     return HiDreamTEModel_
@@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):

 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
-        llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


@@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
         else:
             return super().load_sd(sd)

-def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
+def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
     class QwenImageTEModel_(HunyuanImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
@@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast
 import torch
 import os
 import numbers
+import comfy.utils

 def llama_detect(state_dict, prefix=""):
     out = {}
@@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""):
     if t5_key in state_dict:
         out["dtype_llama"] = state_dict[t5_key].dtype

-    scaled_fp8_key = "{}scaled_fp8".format(prefix)
-    if scaled_fp8_key in state_dict:
-        out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
-
-    if "_quantization_metadata" in state_dict:
-        out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
+    quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
+    if quant is not None:
+        out["llama_quantization_metadata"] = quant

     return out

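With `detect_layer_quantization`, the presence of any `.comfy_quant` marker key now drives detection; a hypothetical sketch (the layer key is illustrative, and the dtype_llama branch is omitted):

```python
# Hypothetical sketch: version-1 quantization markers in a text-encoder
# state dict. Key names are illustrative.
import torch

sd = {
    "model.layers.0.mlp.down_proj.weight": torch.zeros(4, 4, dtype=torch.float8_e4m3fn),
    "model.layers.0.mlp.down_proj.comfy_quant": torch.zeros(8, dtype=torch.uint8),
}
out = llama_detect(sd, prefix="")
# -> {"llama_quantization_metadata": {"mixed_ops": True}}
```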
@@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):

 class LLAMAModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
-        llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata

         textmodel_json_config = {}
         vocab_size = model_options.get("vocab_size", None)
@@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
         return self.llama.load_sd(sd)


-def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
+def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
     class HunyuanVideoClipModel_(HunyuanVideoClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["llama_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
     return HunyuanVideoClipModel_
comfy/text_encoders/kandinsky5.py (new file, 68 additions)
@@ -0,0 +1,68 @@
+from comfy import sd1_clip
+from .qwen_image import QwenImageTokenizer, QwenImageTEModel
+from .llama import Qwen25_7BVLI
+
+
+class Kandinsky5Tokenizer(QwenImageTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+        out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+
+        return out
+
+
+class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+
+
+class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
+            model_options = model_options.copy()
+            model_options["quantization_metadata"] = llama_quantization_metadata
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class Kandinsky5TEModel(QwenImageTEModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs):
+        cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
+        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
+
+        return cond, l_pooled, extra
+
+    def set_clip_options(self, options):
+        super().set_clip_options(options)
+        self.clip_l.set_clip_options(options)
+
+    def reset_clip_options(self):
+        super().reset_clip_options()
+        self.clip_l.reset_clip_options()
+
+    def load_sd(self, sd):
+        if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
+            return self.clip_l.load_sd(sd)
+        else:
+            return super().load_sd(sd)
+
+def te(dtype_llama=None, llama_quantization_metadata=None):
+    class Kandinsky5TEModel_(Kandinsky5TEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return Kandinsky5TEModel_
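The `te()` factory mirrors the pattern used by every text encoder touched in this commit; typical wiring, taken from the supported_models hunk earlier in this diff:

```python
# Wiring sketch based on Kandinsky5.clip_target in this commit.
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(
    state_dict, "text_encoders.qwen25_7b.transformer.")
clip_target = supported_models_base.ClipTarget(
    comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer,
    comfy.text_encoders.kandinsky5.te(**hunyuan_detect),  # dtype_llama / llama_quantization_metadata
)
```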
@@ -100,6 +100,28 @@ class Qwen3_4BConfig:
     rope_scale = None
     final_norm: bool = True

+@dataclass
+class Ovis25_2BConfig:
+    vocab_size: int = 151936
+    hidden_size: int = 2048
+    intermediate_size: int = 6144
+    num_hidden_layers: int = 28
+    num_attention_heads: int = 16
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 40960
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True
+
 @dataclass
 class Qwen25_7BVLI_Config:
     vocab_size: int = 152064
@@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype

+class Ovis25_2B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Ovis25_2BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
@@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
         super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)


-def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
+def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
     if model_type == "gemma2_2b":
         model = Gemma2_2BModel
     elif model_type == "gemma3_4b":
@@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):

     class LuminaTEModel_(LuminaModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
+                model_options["quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
@@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
         super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)


-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
     class Omnigen2TEModel_(Omnigen2Model):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
+                model_options["quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(device=device, dtype=dtype, model_options=model_options)
comfy/text_encoders/ovis.py (new file, 66 additions)
@@ -0,0 +1,66 @@
+from transformers import Qwen2Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+import torch
+import numbers
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class OvisTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
+        self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        return tokens
+
+class Ovis25_2BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
+
+
+class OvisTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs, template_end=-1):
+        out, pooled = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen3_2b"][0]
+        count_im_start = 0
+        if template_end == -1:
+            for i, v in enumerate(tok_pairs):
+                elem = v[0]
+                if not torch.is_tensor(elem):
+                    if isinstance(elem, numbers.Integral):
+                        if elem == 4004 and count_im_start < 1:
+                            template_end = i
+                            count_im_start += 1
+
+        if out.shape[1] > (template_end + 1):
+            if tok_pairs[template_end + 1][0] == 25:
+                template_end += 1
+
+        out = out[:, template_end:]
+        return out, pooled, {}
+
+
+def te(dtype_llama=None, llama_quantization_metadata=None):
+    class OvisTEModel_(OvisTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return OvisTEModel_
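`OvisTEModel.encode_token_weights` trims the fixed chat template off the returned embeddings; a simplified sketch (token ids 4004 and 25 are taken from the code above, their textual values are assumptions):

```python
# Simplified sketch of the template trimming above; assumes id 4004 marks the
# end of the fixed template and is present in the prompt tokens.
tok_ids = [tid for tid, _w in tok_pairs]
template_end = tok_ids.index(4004)
if tok_ids[template_end + 1] == 25:   # presumably the ":" right after the template
    template_end += 1
out = out[:, template_end:]           # conditioning starts at the user prompt
```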
@@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)

-def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
     class PixArtTEModel_(PixArtT5XXL):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             if dtype is None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -179,36 +179,36 @@
       "special": false
     },
     "151665": {
-      "content": "<|img|>",
+      "content": "<tool_response>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
       "single_word": false,
-      "special": true
+      "special": false
     },
     "151666": {
-      "content": "<|endofimg|>",
+      "content": "</tool_response>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
       "single_word": false,
-      "special": true
+      "special": false
     },
     "151667": {
-      "content": "<|meta|>",
+      "content": "<think>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
       "single_word": false,
-      "special": true
+      "special": false
     },
     "151668": {
-      "content": "<|endofmeta|>",
+      "content": "</think>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
       "single_word": false,
-      "special": true
+      "special": false
     }
   },
   "additional_special_tokens": [
@@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
         return out, pooled, extra


-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
     class QwenImageTEModel_(QwenImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
+                model_options["quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -6,14 +6,15 @@ import torch
 import os
 import comfy.model_management
 import logging
+import comfy.utils

 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
-        if t5xxl_scaled_fp8 is not None:
+        t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+        if t5xxl_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = t5xxl_scaled_fp8
+            model_options["quantization_metadata"] = t5xxl_quantization_metadata

         model_options = {**model_options, "model_name": "t5xxl"}
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""):
     if t5_key in state_dict:
         out["dtype_t5"] = state_dict[t5_key].dtype

-    scaled_fp8_key = "{}scaled_fp8".format(prefix)
-    if scaled_fp8_key in state_dict:
-        out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
+    quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
+    if quant is not None:
+        out["t5_quantization_metadata"] = quant

     return out

@@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module):
         else:
             return self.t5xxl.load_sd(sd)

-def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
+def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
     class SD3ClipModel_(SD3ClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
     return SD3ClipModel_
@@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
         super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)

-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
     class WanTEModel(WanT5Model):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["quantization_metadata"] = t5_quantization_metadata
             if dtype_t5 is not None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel):
         super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)


-def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
     class ZImageTEModel_(ZImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
-                model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
             if llama_quantization_metadata is not None:
@@ -29,6 +29,7 @@ import itertools
 from torch.nn.functional import interpolate
 from einops import rearrange
 from comfy.cli_args import args
+import json

 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap
@@ -1194,3 +1195,68 @@ def unpack_latents(combined_latent, latent_shapes):
     else:
         output_tensors = combined_latent
     return output_tensors
+
+def detect_layer_quantization(state_dict, prefix):
+    for k in state_dict:
+        if k.startswith(prefix) and k.endswith(".comfy_quant"):
+            logging.info("Found quantization metadata version 1")
+            return {"mixed_ops": True}
+    return None
+
+def convert_old_quants(state_dict, model_prefix="", metadata={}):
+    if metadata is None:
+        metadata = {}
+
+    quant_metadata = None
+    if "_quantization_metadata" not in metadata:
+        scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
+
+        if scaled_fp8_key in state_dict:
+            scaled_fp8_weight = state_dict[scaled_fp8_key]
+            scaled_fp8_dtype = scaled_fp8_weight.dtype
+            if scaled_fp8_dtype == torch.float32:
+                scaled_fp8_dtype = torch.float8_e4m3fn
+
+            if scaled_fp8_weight.nelement() == 2:
+                full_precision_matrix_mult = True
+            else:
+                full_precision_matrix_mult = False
+
+            out_sd = {}
+            layers = {}
+            for k in list(state_dict.keys()):
+                if not k.startswith(model_prefix):
+                    out_sd[k] = state_dict[k]
+                    continue
+                k_out = k
+                w = state_dict.pop(k)
+                layer = None
+                if k_out.endswith(".scale_weight"):
+                    layer = k_out[:-len(".scale_weight")]
+                    k_out = "{}.weight_scale".format(layer)
+
+                if layer is not None:
+                    layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+                    if full_precision_matrix_mult:
+                        layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+                    layers[layer] = layer_conf
+
+                if k_out.endswith(".scale_input"):
+                    layer = k_out[:-len(".scale_input")]
+                    k_out = "{}.input_scale".format(layer)
+                    if w.item() == 1.0:
+                        continue
+
+                out_sd[k_out] = w
+
+            state_dict = out_sd
+            quant_metadata = {"layers": layers}
+    else:
+        quant_metadata = json.loads(metadata["_quantization_metadata"])
+
+    if quant_metadata is not None:
+        layers = quant_metadata["layers"]
+        for k, v in layers.items():
+            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+
+    return state_dict, metadata
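`convert_old_quants` rewrites legacy scaled-fp8 checkpoints into the per-layer `.comfy_quant` convention; a hypothetical round trip (key names illustrative):

```python
# Hypothetical sketch: migrating a legacy scaled-fp8 state dict.
import torch
import comfy.utils

sd = {
    "scaled_fp8": torch.zeros(2, dtype=torch.float8_e4m3fn),   # legacy marker, nelement() == 2
    "layers.0.weight": torch.zeros(4, 4, dtype=torch.float8_e4m3fn),
    "layers.0.scale_weight": torch.tensor(0.5),
}
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="")
# "layers.0.scale_weight" becomes "layers.0.weight_scale", and
# "layers.0.comfy_quant" now holds the JSON layer config
# {"format": "float8_e4m3fn", "full_precision_matrix_mult": True} as uint8 bytes,
# so detect_layer_quantization(sd, "") returns {"mixed_ops": True}.
```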
@@ -13,6 +13,7 @@ from comfy.cli_args import args
 SERVER_FEATURE_FLAGS: Dict[str, Any] = {
     "supports_preview_metadata": True,
     "max_upload_size": args.max_upload_size * 1024 * 1024,  # Convert MB to bytes
+    "extension": {"manager": {"supports_v4": True}},
 }

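Clients can read the new flag from the feature-flag dictionary; for example:

```python
# Reading the new manager flag from SERVER_FEATURE_FLAGS.
flags = SERVER_FEATURE_FLAGS
supports_v4 = flags.get("extension", {}).get("manager", {}).get("supports_v4", False)
```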
@@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
 from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
 from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
 from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
-from . import _io as io
-from . import _ui as ui
+from . import _io_public as io
+from . import _ui_public as ui
 # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
 from comfy_execution.utils import get_executing_context
 from comfy_execution.progress import get_progress_state, PreviewImageTuple
@@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput):
             raise ValueError("Only MP4 format is supported for now")
         if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
             raise ValueError("Only H264 codec is supported for now")
-        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
+        extra_kwargs = {}
+        if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
+            extra_kwargs["format"] = format.value
+        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
             # Add metadata before writing any streams
             if metadata is not None:
                 for key, value in metadata.items():
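Passing the container format explicitly keeps PyAV from inferring it from the output file extension; a sketch of the resulting call (stream setup is illustrative):

```python
# Sketch: what the av.open call above expands to for an MP4 target.
import av

with av.open("out.mp4", mode="w", format="mp4",
             options={"movflags": "use_metadata_tags"}) as output:
    stream = output.add_stream("h264", rate=24)  # illustrative encoder setup
```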
@ -4,7 +4,8 @@ import copy
|
|||||||
import inspect
|
import inspect
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from dataclasses import asdict, dataclass
|
from collections.abc import Iterable
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
|
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
|
||||||
from typing_extensions import NotRequired, final
|
from typing_extensions import NotRequired, final
|
||||||
@@ -150,6 +151,9 @@ class _IO_V3:
     def __init__(self):
         pass

+    def validate(self):
+        pass
+
     @property
     def io_type(self):
         return self.Parent.io_type
@@ -182,6 +186,9 @@ class Input(_IO_V3):
     def get_io_type(self):
         return _StringIOType(self.io_type)

+    def get_all(self) -> list[Input]:
+        return [self]
+
 class WidgetInput(Input):
     '''
     Base class for a V3 Input with widget.
@@ -561,6 +568,8 @@ class Conditioning(ComfyTypeIO):
         '''Used by WAN Camera.'''
         time_dim_concat: NotRequired[torch.Tensor]
         '''Used by WAN Phantom Subject.'''
+        time_dim_replace: NotRequired[torch.Tensor]
+        '''Used by Kandinsky5 I2V.'''

     CondList = list[tuple[torch.Tensor, PooledDict]]
     Type = CondList
@@ -814,13 +823,61 @@ class MultiType:
         else:
             return super().as_dict()

+@comfytype(io_type="COMFY_MATCHTYPE_V3")
+class MatchType(ComfyTypeIO):
+    class Template:
+        def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
+            self.template_id = template_id
+            # account for syntactic sugar
+            if not isinstance(allowed_types, Iterable):
+                allowed_types = [allowed_types]
+            for t in allowed_types:
+                if not isinstance(t, type):
+                    if not isinstance(t, _ComfyType):
+                        raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
+                else:
+                    if not issubclass(t, _ComfyType):
+                        raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
+            self.allowed_types = allowed_types
+
+        def as_dict(self):
+            return {
+                "template_id": self.template_id,
+                "allowed_types": ",".join([t.io_type for t in self.allowed_types]),
+            }
+
+    class Input(Input):
+        def __init__(self, id: str, template: MatchType.Template,
+                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+            self.template = template
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "template": self.template.as_dict(),
+            })
+
+    class Output(Output):
+        def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
+                     is_output_list=False):
+            super().__init__(id, display_name, tooltip, is_output_list)
+            self.template = template
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "template": self.template.as_dict(),
+            })
+
 class DynamicInput(Input, ABC):
     '''
     Abstract class for dynamic input registration.
     '''
-    @abstractmethod
     def get_dynamic(self) -> list[Input]:
-        ...
+        return []
+
+    def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+        pass
+

 class DynamicOutput(Output, ABC):
     '''
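Taken together, the MatchType pieces above let a node declare that one input and one output resolve to the same runtime type. A minimal sketch of how a schema might use it (the `io` module alias and the concrete type names are assumptions, not code from this commit):

```python
# Hypothetical sketch: input and output share a template, so the frontend
# narrows both slots to whichever allowed type actually gets connected.
T = io.MatchType.Template("T", allowed_types=[io.Image, io.Latent])

inputs = [io.MatchType.Input("value", template=T)]
outputs = [io.MatchType.Output(template=T)]
```

The serialized form joins the allowed io_types with commas, which is what the frontend consumes via `as_dict()`.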
@@ -830,99 +887,223 @@ class DynamicOutput(Output, ABC):
                  is_output_list=False):
         super().__init__(id, display_name, tooltip, is_output_list)

-    @abstractmethod
     def get_dynamic(self) -> list[Output]:
-        ...
+        return []
+

 @comfytype(io_type="COMFY_AUTOGROW_V3")
-class AutogrowDynamic(ComfyTypeI):
-    Type = list[Any]
-    class Input(DynamicInput):
-        def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
-                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
-            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
-            self.template_input = template_input
-            if min is not None:
-                assert(min >= 1)
-            if max is not None:
+class Autogrow(ComfyTypeI):
+    Type = dict[str, Any]
+    _MaxNames = 100 # NOTE: max 100 names for sanity
+
+    class _AutogrowTemplate:
+        def __init__(self, input: Input):
+            # dynamic inputs are not allowed as the template input
+            assert(not isinstance(input, DynamicInput))
+            self.input = copy.copy(input)
+            if isinstance(self.input, WidgetInput):
+                self.input.force_input = True
+            self.names: list[str] = []
+            self.cached_inputs = {}
+
+        def _create_input(self, input: Input, name: str):
+            new_input = copy.copy(self.input)
+            new_input.id = name
+            return new_input
+
+        def _create_cached_inputs(self):
+            for name in self.names:
+                self.cached_inputs[name] = self._create_input(self.input, name)
+
+        def get_all(self) -> list[Input]:
+            return list(self.cached_inputs.values())
+
+        def as_dict(self):
+            return prune_dict({
+                "input": create_input_dict_v1([self.input]),
+            })
+
+        def validate(self):
+            self.input.validate()
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            real_inputs = []
+            for name, input in self.cached_inputs.items():
+                if name in live_inputs:
+                    real_inputs.append(input)
+            add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
+            add_dynamic_id_mapping(d, real_inputs, curr_prefix)
+
+    class TemplatePrefix(_AutogrowTemplate):
+        def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
+            super().__init__(input)
+            self.prefix = prefix
+            assert(min >= 0)
             assert(max >= 1)
+            assert(max <= Autogrow._MaxNames)
             self.min = min
             self.max = max
+            self.names = [f"{self.prefix}{i}" for i in range(self.max)]
+            self._create_cached_inputs()
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "prefix": self.prefix,
+                "min": self.min,
+                "max": self.max,
+            })
+
+    class TemplateNames(_AutogrowTemplate):
+        def __init__(self, input: Input, names: list[str], min: int=1):
+            super().__init__(input)
+            self.names = names[:Autogrow._MaxNames]
+            assert(min >= 0)
+            self.min = min
+            self._create_cached_inputs()
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "names": self.names,
+                "min": self.min,
+            })
+
+    class Input(DynamicInput):
+        def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
+                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+            self.template = template
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "template": self.template.as_dict(),
+            })

         def get_dynamic(self) -> list[Input]:
-            curr_count = 1
-            new_inputs = []
-            for i in range(self.min):
-                new_input = copy.copy(self.template_input)
-                new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
-                if new_input.display_name is not None:
-                    new_input.display_name = f"{new_input.display_name}{curr_count}"
-                new_input.optional = self.optional or new_input.optional
-                if isinstance(self.template_input, WidgetInput):
-                    new_input.force_input = True
-                new_inputs.append(new_input)
-                curr_count += 1
-            # pretend to expand up to max
-            for i in range(curr_count-1, self.max):
-                new_input = copy.copy(self.template_input)
-                new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
-                if new_input.display_name is not None:
-                    new_input.display_name = f"{new_input.display_name}{curr_count}"
-                new_input.optional = True
-                if isinstance(self.template_input, WidgetInput):
-                    new_input.force_input = True
-                new_inputs.append(new_input)
-                curr_count += 1
-            return new_inputs
-
-@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
-class ComboDynamic(ComfyTypeI):
-    class Input(DynamicInput):
-        def __init__(self, id: str):
-            pass
-
-@comfytype(io_type="COMFY_MATCHTYPE_V3")
-class MatchType(ComfyTypeIO):
-    class Template:
-        def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
-            self.template_id = template_id
-            self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
+            return self.template.get_all()
+
+        def get_all(self) -> list[Input]:
+            return [self] + self.template.get_all()
+
+        def validate(self):
+            self.template.validate()
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            curr_prefix = f"{curr_prefix}{self.id}."
+            # need to remove self from expected inputs dictionary; replaced by template inputs in frontend
+            for inner_dict in d.values():
+                if self.id in inner_dict:
+                    del inner_dict[self.id]
+            self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+
+@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
+class DynamicCombo(ComfyTypeI):
+    Type = dict[str, Any]
+
+    class Option:
+        def __init__(self, key: str, inputs: list[Input]):
+            self.key = key
+            self.inputs = inputs

         def as_dict(self):
             return {
-                "template_id": self.template_id,
-                "allowed_types": "".join(t.io_type for t in self.allowed_types),
+                "key": self.key,
+                "inputs": create_input_dict_v1(self.inputs),
             }

     class Input(DynamicInput):
-        def __init__(self, id: str, template: MatchType.Template,
+        def __init__(self, id: str, options: list[DynamicCombo.Option],
                      display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
             super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
-            self.template = template
+            self.options = options
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            # check if dynamic input's id is in live_inputs
+            if self.id in live_inputs:
+                curr_prefix = f"{curr_prefix}{self.id}."
+                key = live_inputs[self.id]
+                selected_option = None
+                for option in self.options:
+                    if option.key == key:
+                        selected_option = option
+                        break
+                if selected_option is not None:
+                    add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
+                    add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)

         def get_dynamic(self) -> list[Input]:
-            return [self]
+            return [input for option in self.options for input in option.inputs]
+
+        def get_all(self) -> list[Input]:
+            return [self] + [input for option in self.options for input in option.inputs]

         def as_dict(self):
             return super().as_dict() | prune_dict({
-                "template": self.template.as_dict(),
+                "options": [o.as_dict() for o in self.options],
             })

-    class Output(DynamicOutput):
-        def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
-                     is_output_list=False):
-            super().__init__(id, display_name, tooltip, is_output_list)
-            self.template = template
-
-        def get_dynamic(self) -> list[Output]:
-            return [self]
+        def validate(self):
+            # make sure all nested inputs are validated
+            for option in self.options:
+                for input in option.inputs:
+                    input.validate()
+
+@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
+class DynamicSlot(ComfyTypeI):
+    Type = dict[str, Any]
+
+    class Input(DynamicInput):
+        def __init__(self, slot: Input, inputs: list[Input],
+                     display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            assert(not isinstance(slot, DynamicInput))
+            self.slot = copy.copy(slot)
+            self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
+            optional = True
+            self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
+            self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
+            self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
+            super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
+            self.inputs = inputs
+            self.force_input = None
+            # force widget inputs to have no widgets, otherwise this would be awkward
+            if isinstance(self.slot, WidgetInput):
+                self.force_input = True
+                self.slot.force_input = True
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            if self.id in live_inputs:
+                curr_prefix = f"{curr_prefix}{self.id}."
+                add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
+            add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
+
+        def get_dynamic(self) -> list[Input]:
+            return [self.slot] + self.inputs
+
+        def get_all(self) -> list[Input]:
+            return [self] + [self.slot] + self.inputs

         def as_dict(self):
             return super().as_dict() | prune_dict({
-                "template": self.template.as_dict(),
+                "slotType": str(self.slot.get_io_type()),
+                "inputs": create_input_dict_v1(self.inputs),
+                "forceInput": self.force_input,
             })

+        def validate(self):
+            self.slot.validate()
+            for input in self.inputs:
+                input.validate()
+
+def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
+    dynamic = d.setdefault("dynamic_paths", {})
+    if self is not None:
+        dynamic[self.id] = f"{curr_prefix}{self.id}"
+    for i in inputs:
+        if not isinstance(i, DynamicInput):
+            dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
+
+class V3Data(TypedDict):
+    hidden_inputs: dict[str, Any]
+    dynamic_paths: dict[str, Any]
+
 class HiddenHolder:
     def __init__(self, unique_id: str, prompt: Any,
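To make the shape of the new dynamic types concrete, here is a hypothetical input list using them; note that the export list later in this diff still keeps `DynamicCombo` and `Autogrow` commented out, so treat this purely as a sketch of the intended API:

```python
# Hypothetical sketch: a combo whose selected key decides which extra inputs
# appear, plus an autogrow group that materializes a capped pool of image slots.
inputs = [
    io.DynamicCombo.Input("mode", options=[
        io.DynamicCombo.Option("scale", [io.Float.Input("factor")]),
        io.DynamicCombo.Option("resize", [io.Int.Input("width"), io.Int.Input("height")]),
    ]),
    io.Autogrow.Input(
        "images",
        template=io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", max=8),
    ),
]
```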
@@ -984,6 +1165,7 @@ class NodeInfoV1:
     output_is_list: list[bool]=None
     output_name: list[str]=None
     output_tooltips: list[str]=None
+    output_matchtypes: list[str]=None
     name: str=None
     display_name: str=None
     description: str=None
@@ -1019,9 +1201,9 @@ class Schema:
     """Display name of node."""
     category: str = "sd"
     """The category of the node, as per the "Add Node" menu."""
-    inputs: list[Input]=None
-    outputs: list[Output]=None
-    hidden: list[Hidden]=None
+    inputs: list[Input] = field(default_factory=list)
+    outputs: list[Output] = field(default_factory=list)
+    hidden: list[Hidden] = field(default_factory=list)
     description: str=""
     """Node description, shown as a tooltip when hovering over the node."""
     is_input_list: bool = False
@@ -1061,7 +1243,11 @@ class Schema:
         '''Validate the schema:
         - verify ids on inputs and outputs are unique - both internally and in relation to each other
         '''
-        input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
+        nested_inputs: list[Input] = []
+        if self.inputs is not None:
+            for input in self.inputs:
+                nested_inputs.extend(input.get_all())
+        input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
         output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
         input_set = set(input_ids)
         output_set = set(output_ids)
@@ -1077,6 +1263,13 @@ class Schema:
             issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
         if len(issues) > 0:
             raise ValueError("\n".join(issues))
+        # validate inputs and outputs
+        if self.inputs is not None:
+            for input in self.inputs:
+                input.validate()
+        if self.outputs is not None:
+            for output in self.outputs:
+                output.validate()

     def finalize(self):
         """Add hidden based on selected schema options, and give outputs without ids default ids."""
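As an illustration of what the stricter validation catches, a schema whose output id collides with an input id (including ids contributed by `get_all()` on dynamic inputs) now raises at validate time (hypothetical sketch):

```python
# Hypothetical sketch: duplicate ids across inputs and outputs raise ValueError.
schema = io.Schema(
    node_id="ExampleNode",
    inputs=[io.Image.Input("image")],
    outputs=[io.Image.Output(id="image")],  # collides with the input id
)
schema.validate()  # ValueError: Ids must be unique between inputs and outputs...
```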
@@ -1102,19 +1295,10 @@ class Schema:
             if output.id is None:
                 output.id = f"_{i}_{output.io_type}_"

-    def get_v1_info(self, cls) -> NodeInfoV1:
+    def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
+        # NOTE: live_inputs will not be used anymore very soon and this will be done another way
         # get V1 inputs
-        input = {
-            "required": {}
-        }
-        if self.inputs:
-            for i in self.inputs:
-                if isinstance(i, DynamicInput):
-                    dynamic_inputs = i.get_dynamic()
-                    for d in dynamic_inputs:
-                        add_to_dict_v1(d, input)
-                else:
-                    add_to_dict_v1(i, input)
+        input = create_input_dict_v1(self.inputs, live_inputs)
         if self.hidden:
             for hidden in self.hidden:
                 input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@@ -1123,12 +1307,24 @@ class Schema:
         output_is_list = []
         output_name = []
         output_tooltips = []
+        output_matchtypes = []
+        any_matchtypes = False
         if self.outputs:
             for o in self.outputs:
                 output.append(o.io_type)
                 output_is_list.append(o.is_output_list)
                 output_name.append(o.display_name if o.display_name else o.io_type)
                 output_tooltips.append(o.tooltip if o.tooltip else None)
+                # special handling for MatchType
+                if isinstance(o, MatchType.Output):
+                    output_matchtypes.append(o.template.template_id)
+                    any_matchtypes = True
+                else:
+                    output_matchtypes.append(None)
+
+        # clear out lists that are all None
+        if not any_matchtypes:
+            output_matchtypes = None

         info = NodeInfoV1(
             input=input,
|
|||||||
output_is_list=output_is_list,
|
output_is_list=output_is_list,
|
||||||
output_name=output_name,
|
output_name=output_name,
|
||||||
output_tooltips=output_tooltips,
|
output_tooltips=output_tooltips,
|
||||||
|
output_matchtypes=output_matchtypes,
|
||||||
name=self.node_id,
|
name=self.node_id,
|
||||||
display_name=self.display_name,
|
display_name=self.display_name,
|
||||||
category=self.category,
|
category=self.category,
|
||||||
@ -1182,16 +1379,57 @@ class Schema:
|
|||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
def add_to_dict_v1(i: Input, input: dict):
|
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
|
||||||
|
input = {
|
||||||
|
"required": {}
|
||||||
|
}
|
||||||
|
add_to_input_dict_v1(input, inputs, live_inputs)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
|
||||||
|
for i in inputs:
|
||||||
|
if isinstance(i, DynamicInput):
|
||||||
|
add_to_dict_v1(i, d)
|
||||||
|
if live_inputs is not None:
|
||||||
|
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
||||||
|
else:
|
||||||
|
add_to_dict_v1(i, d)
|
||||||
|
|
||||||
|
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
|
||||||
key = "optional" if i.optional else "required"
|
key = "optional" if i.optional else "required"
|
||||||
as_dict = i.as_dict()
|
as_dict = i.as_dict()
|
||||||
# for v1, we don't want to include the optional key
|
# for v1, we don't want to include the optional key
|
||||||
as_dict.pop("optional", None)
|
as_dict.pop("optional", None)
|
||||||
input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
if dynamic_dict is None:
|
||||||
|
value = (i.get_io_type(), as_dict)
|
||||||
|
else:
|
||||||
|
value = (i.get_io_type(), as_dict, dynamic_dict)
|
||||||
|
d.setdefault(key, {})[i.id] = value
|
||||||
|
|
||||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||||
|
|
||||||
|
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||||
|
paths = v3_data.get("dynamic_paths", None)
|
||||||
|
if paths is None:
|
||||||
|
return values
|
||||||
|
values = values.copy()
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
for key, path in paths.items():
|
||||||
|
parts = path.split(".")
|
||||||
|
current = result
|
||||||
|
|
||||||
|
for i, p in enumerate(parts):
|
||||||
|
is_last = (i == len(parts) - 1)
|
||||||
|
|
||||||
|
if is_last:
|
||||||
|
current[p] = values.pop(key, None)
|
||||||
|
else:
|
||||||
|
current = current.setdefault(p, {})
|
||||||
|
|
||||||
|
values.update(result)
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||||
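A worked example of the re-nesting that `build_nested_inputs` performs, using the dot-paths recorded by `add_dynamic_id_mapping` (the values themselves are hypothetical):

```python
# Flat kwargs as the executor sees them:
values = {"mode": "scale", "factor": 2.0, "seed": 7}
# Paths recorded for a DynamicCombo input with id "mode":
v3_data = {"dynamic_paths": {"mode": "mode.mode", "factor": "mode.factor"}}

# build_nested_inputs(values, v3_data) pops each mapped key and re-nests it:
# {"seed": 7, "mode": {"mode": "scale", "factor": 2.0}}
```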
@@ -1311,12 +1549,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):

     @final
     @classmethod
-    def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
+    def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
         """Creates clone of real node class to prevent monkey-patching."""
         c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
         type_clone: type[ComfyNode] = shallow_clone_class(c_type)
         # set hidden
-        type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
+        type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
         return type_clone

     @final
@@ -1433,14 +1671,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):

     @final
     @classmethod
-    def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
+    def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
         schema = cls.FINALIZE_SCHEMA()
-        info = schema.get_v1_info(cls)
+        info = schema.get_v1_info(cls, live_inputs)
         input = info.input
         if not include_hidden:
             input.pop("hidden", None)
         if return_schema:
-            return input, schema
+            v3_data: V3Data = {}
+            dynamic = input.pop("dynamic_paths", None)
+            if dynamic is not None:
+                v3_data["dynamic_paths"] = dynamic
+            return input, schema, v3_data
         return input

     @final
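With the widened signature, a caller that asks for the schema back now also receives the collected V3 data; a sketch against a hypothetical node class:

```python
# Hypothetical sketch: INPUT_TYPES with return_schema=True now yields a 3-tuple.
input_types, schema, v3_data = MyNode.INPUT_TYPES(
    return_schema=True,
    live_inputs={"mode": "scale"},
)
paths = v3_data.get("dynamic_paths", {})  # e.g. {"mode": "mode.mode", ...}
```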
@@ -1513,7 +1755,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
         raise NotImplementedError

     @classmethod
-    def validate_inputs(cls, **kwargs) -> bool:
+    def validate_inputs(cls, **kwargs) -> bool | str:
         """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
         raise NotImplementedError
@@ -1628,6 +1870,7 @@ __all__ = [
     "StyleModel",
     "Gligen",
     "UpscaleModel",
+    "LatentUpscaleModel",
     "Audio",
     "Video",
     "SVG",
@@ -1651,6 +1894,10 @@ __all__ = [
     "SEGS",
     "AnyType",
     "MultiType",
+    # Dynamic Types
+    "MatchType",
+    # "DynamicCombo",
+    # "Autogrow",
     # Other classes
     "HiddenHolder",
     "Hidden",
@@ -1661,4 +1908,5 @@ __all__ = [
     "NodeOutput",
     "add_to_dict_v1",
     "add_to_dict_v3",
+    "V3Data",
 ]
comfy_api/latest/_io_public.py (new file, 1 line)
@@ -0,0 +1 @@
+from ._io import * # noqa: F403
@@ -3,6 +3,7 @@ from __future__ import annotations
 import json
 import os
 import random
+import uuid
 from io import BytesIO
 from typing import Type
@@ -318,9 +319,10 @@ class AudioSaveHelper:
             for key, value in metadata.items():
                 output_container.metadata[key] = value

+        layout = "mono" if waveform.shape[0] == 1 else "stereo"
         # Set up the output stream with appropriate properties
         if format == "opus":
-            out_stream = output_container.add_stream("libopus", rate=sample_rate)
+            out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
             if quality == "64k":
                 out_stream.bit_rate = 64000
             elif quality == "96k":
@@ -332,7 +334,7 @@ class AudioSaveHelper:
             elif quality == "320k":
                 out_stream.bit_rate = 320000
         elif format == "mp3":
-            out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
+            out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
             if quality == "V0":
                 # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
                 out_stream.codec_context.qscale = 1
@@ -341,12 +343,12 @@ class AudioSaveHelper:
             elif quality == "320k":
                 out_stream.bit_rate = 320000
         else: # format == "flac":
-            out_stream = output_container.add_stream("flac", rate=sample_rate)
+            out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)

         frame = av.AudioFrame.from_ndarray(
             waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
             format="flt",
-            layout="mono" if waveform.shape[0] == 1 else "stereo",
+            layout=layout,
         )
         frame.sample_rate = sample_rate
         frame.pts = 0
@@ -436,9 +438,19 @@ class PreviewUI3D(_UIOutput):
     def __init__(self, model_file, camera_info, **kwargs):
         self.model_file = model_file
         self.camera_info = camera_info
+        self.bg_image_path = None
+        bg_image = kwargs.get("bg_image", None)
+        if bg_image is not None:
+            img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
+            img = PILImage.fromarray(img_array)
+            temp_dir = folder_paths.get_temp_directory()
+            filename = f"bg_{uuid.uuid4().hex}.png"
+            bg_image_path = os.path.join(temp_dir, filename)
+            img.save(bg_image_path, compress_level=1)
+            self.bg_image_path = f"temp/{filename}"

     def as_dict(self):
-        return {"result": [self.model_file, self.camera_info]}
+        return {"result": [self.model_file, self.camera_info, self.bg_image_path]}


 class PreviewText(_UIOutput):
comfy_api/latest/_ui_public.py (new file, 1 line)
@@ -0,0 +1 @@
+from ._ui import * # noqa: F403
@@ -6,7 +6,7 @@ from comfy_api.latest import (
 )
 from typing import Type, TYPE_CHECKING
 from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
+from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401


 class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -42,4 +42,8 @@ __all__ = [
     "InputImpl",
     "Types",
     "ComfyExtension",
+    "io",
+    "IO",
+    "ui",
+    "UI",
 ]
86
comfy_api_nodes/apis/kling_api.py
Normal file
86
comfy_api_nodes/apis/kling_api.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class OmniProText2VideoRequest(BaseModel):
|
||||||
|
model_name: str = Field(..., description="kling-video-o1")
|
||||||
|
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||||
|
duration: str = Field(..., description="'5' or '10'")
|
||||||
|
prompt: str = Field(...)
|
||||||
|
mode: str = Field("pro")
|
||||||
|
|
||||||
|
|
||||||
|
class OmniParamImage(BaseModel):
|
||||||
|
image_url: str = Field(...)
|
||||||
|
type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
|
||||||
|
|
||||||
|
|
||||||
|
class OmniParamVideo(BaseModel):
|
||||||
|
video_url: str = Field(...)
|
||||||
|
refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
|
||||||
|
keep_original_sound: str = Field(..., description="'yes' or 'no'")
|
||||||
|
|
||||||
|
|
||||||
|
class OmniProFirstLastFrameRequest(BaseModel):
|
||||||
|
model_name: str = Field(..., description="kling-video-o1")
|
||||||
|
image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
|
||||||
|
duration: str = Field(..., description="'5' or '10'")
|
||||||
|
prompt: str = Field(...)
|
||||||
|
mode: str = Field("pro")
|
||||||
|
|
||||||
|
|
||||||
|
class OmniProReferences2VideoRequest(BaseModel):
|
||||||
|
model_name: str = Field(..., description="kling-video-o1")
|
||||||
|
aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||||
|
image_list: list[OmniParamImage] | None = Field(
|
||||||
|
None, max_length=7, description="Max length 4 when video is present."
|
||||||
|
)
|
||||||
|
video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
|
||||||
|
duration: str | None = Field(..., description="From 3 to 10.")
|
||||||
|
prompt: str = Field(...)
|
||||||
|
mode: str = Field("pro")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatusVideoResult(BaseModel):
|
||||||
|
duration: str | None = Field(None, description="Total video duration")
|
||||||
|
id: str | None = Field(None, description="Generated video ID")
|
||||||
|
url: str | None = Field(None, description="URL for generated video")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatusImageResult(BaseModel):
|
||||||
|
index: int = Field(..., description="Image Number,0-9")
|
||||||
|
url: str = Field(..., description="URL for generated image")
|
||||||
|
|
||||||
|
|
||||||
|
class OmniTaskStatusResults(BaseModel):
|
||||||
|
videos: list[TaskStatusVideoResult] | None = Field(None)
|
||||||
|
images: list[TaskStatusImageResult] | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class OmniTaskStatusResponseData(BaseModel):
|
||||||
|
created_at: int | None = Field(None, description="Task creation time")
|
||||||
|
updated_at: int | None = Field(None, description="Task update time")
|
||||||
|
task_status: str | None = None
|
||||||
|
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
|
||||||
|
task_id: str | None = Field(None, description="Task ID")
|
||||||
|
task_result: OmniTaskStatusResults | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class OmniTaskStatusResponse(BaseModel):
|
||||||
|
code: int | None = Field(None, description="Error code")
|
||||||
|
message: str | None = Field(None, description="Error message")
|
||||||
|
request_id: str | None = Field(None, description="Request ID")
|
||||||
|
data: OmniTaskStatusResponseData | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class OmniImageParamImage(BaseModel):
|
||||||
|
image: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class OmniProImageRequest(BaseModel):
|
||||||
|
model_name: str = Field(..., description="kling-image-o1")
|
||||||
|
resolution: str = Field(..., description="'1k' or '2k'")
|
||||||
|
aspect_ratio: str | None = Field(...)
|
||||||
|
prompt: str = Field(...)
|
||||||
|
mode: str = Field("pro")
|
||||||
|
n: int | None = Field(1, le=9)
|
||||||
|
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
|
||||||
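These are plain pydantic models, so building a request body is keyword construction plus a dump; a sketch (assuming pydantic v2, which the `min_length`/`max_length` list constraints imply):

```python
# Hypothetical sketch: constructing and serializing a text-to-video request.
req = OmniProText2VideoRequest(
    model_name="kling-video-o1",
    aspect_ratio="16:9",
    duration="5",
    prompt="a red fox running through fresh snow",
)
payload = req.model_dump()  # {'model_name': 'kling-video-o1', ..., 'mode': 'pro'}
```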
@@ -4,13 +4,14 @@ For source of truth on the allowed permutations of request fields, please refere
 - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
 """

-import math
 import logging
-from typing_extensions import override
+import math
+import re

 import torch
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
 from comfy_api_nodes.apis import (
     KlingCameraControl,
     KlingCameraConfig,
@@ -48,23 +49,33 @@ from comfy_api_nodes.apis import (
     KlingCharacterEffectModelName,
     KlingSingleImageEffectModelName,
 )
+from comfy_api_nodes.apis.kling_api import (
+    OmniImageParamImage,
+    OmniParamImage,
+    OmniParamVideo,
+    OmniProFirstLastFrameRequest,
+    OmniProImageRequest,
+    OmniProReferences2VideoRequest,
+    OmniProText2VideoRequest,
+    OmniTaskStatusResponse,
+)
 from comfy_api_nodes.util import (
-    validate_image_dimensions,
+    ApiEndpoint,
+    download_url_to_image_tensor,
+    download_url_to_video_output,
+    get_number_of_images,
+    poll_op,
+    sync_op,
+    tensor_to_base64_string,
+    upload_audio_to_comfyapi,
+    upload_images_to_comfyapi,
+    upload_video_to_comfyapi,
     validate_image_aspect_ratio,
+    validate_image_dimensions,
+    validate_string,
     validate_video_dimensions,
     validate_video_duration,
-    tensor_to_base64_string,
-    validate_string,
-    upload_audio_to_comfyapi,
-    download_url_to_image_tensor,
-    upload_video_to_comfyapi,
-    download_url_to_video_output,
-    sync_op,
-    ApiEndpoint,
-    poll_op,
 )
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO, Input

 KLING_API_VERSION = "v1"
 PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
@@ -202,6 +213,50 @@ VOICES_CONFIG = {
 }


+def normalize_omni_prompt_references(prompt: str) -> str:
+    """
+    Rewrites Kling Omni-style placeholders used in the app, like:
+
+        @image, @image1, @image2, ... @imageN
+        @video, @video1, @video2, ... @videoN
+
+    into the API-compatible form:
+
+        <<<image_1>>>, <<<image_2>>>, ...
+        <<<video_1>>>, <<<video_2>>>, ...
+
+    This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
+    """
+    if not prompt:
+        return prompt
+
+    def _image_repl(match):
+        return f"<<<image_{match.group('idx') or '1'}>>>"
+
+    def _video_repl(match):
+        return f"<<<video_{match.group('idx') or '1'}>>>"
+
+    # (?<!\w) avoids matching e.g. "test@image.com"
+    # (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo
+    prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt)
+    return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
+
+
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
+    if response.code:
+        raise RuntimeError(
+            f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+        )
+    final_response = await poll_op(
+        cls,
+        ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
+        response_model=OmniTaskStatusResponse,
+        status_extractor=lambda r: (r.data.task_status if r.data else None),
+        max_poll_attempts=160,
+    )
+    return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
 def is_valid_camera_control_configs(configs: list[float]) -> bool:
     """Verifies that at least one camera control configuration is non-zero."""
     return any(not math.isclose(value, 0.0) for value in configs)
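Concretely, the placeholder rewrite above behaves like this (expected outputs shown as comments):

```python
normalize_omni_prompt_references("Put @image1 next to @image2")
# -> "Put <<<image_1>>> next to <<<image_2>>>"

normalize_omni_prompt_references("@video styled like @image")
# -> "<<<video_1>>> styled like <<<image_1>>>"

normalize_omni_prompt_references("mail test@image.com")
# -> unchanged: the (?<!\w) guard rejects the 't' preceding '@' in 'test@image.com'
```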
@@ -449,7 +504,7 @@ async def execute_video_effect(
     image_1: torch.Tensor,
     image_2: torch.Tensor | None = None,
     model_mode: KlingVideoGenMode | None = None,
-) -> tuple[VideoFromFile, str, str]:
+) -> tuple[InputImpl.VideoFromFile, str, str]:
     if dual_character:
         request_input_field = KlingDualCharacterEffectInput(
             model_name=model_name,
@@ -736,6 +791,474 @@ class KlingTextToVideoNode(IO.ComfyNode):
         )


+class OmniProTextToVideoNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProTextToVideoNode",
+            display_name="Kling Omni Text to Video (Pro)",
+            category="api node/video/Kling",
+            description="Use text prompts to generate videos with the latest Kling model.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+                IO.Combo.Input("duration", options=[5, 10]),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        aspect_ratio: str,
+        duration: int,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, min_length=1, max_length=2500)
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProText2VideoRequest(
+                model_name=model_name,
+                prompt=prompt,
+                aspect_ratio=aspect_ratio,
+                duration=str(duration),
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProFirstLastFrameNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProFirstLastFrameNode",
+            display_name="Kling Omni First-Last-Frame to Video (Pro)",
+            category="api node/video/Kling",
+            description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("duration", options=["5", "10"]),
+                IO.Image.Input("first_frame"),
+                IO.Image.Input(
+                    "end_frame",
+                    optional=True,
+                    tooltip="An optional end frame for the video. "
+                    "This cannot be used simultaneously with 'reference_images'.",
+                ),
+                IO.Image.Input(
+                    "reference_images",
+                    optional=True,
+                    tooltip="Up to 6 additional reference images.",
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        duration: int,
+        first_frame: Input.Image,
+        end_frame: Input.Image | None = None,
+        reference_images: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        if end_frame is not None and reference_images is not None:
+            raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
+        validate_image_dimensions(first_frame, min_width=300, min_height=300)
+        validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
+        image_list: list[OmniParamImage] = [
+            OmniParamImage(
+                image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0],
+                type="first_frame",
+            )
+        ]
+        if end_frame is not None:
+            validate_image_dimensions(end_frame, min_width=300, min_height=300)
+            validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
+            image_list.append(
+                OmniParamImage(
+                    image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0],
+                    type="end_frame",
+                )
+            )
+        if reference_images is not None:
+            if get_number_of_images(reference_images) > 6:
+                raise ValueError("The maximum number of reference images allowed is 6.")
+            for i in reference_images:
+                validate_image_dimensions(i, min_width=300, min_height=300)
+                validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+            for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
+                image_list.append(OmniParamImage(image_url=i))
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProFirstLastFrameRequest(
+                model_name=model_name,
+                prompt=prompt,
+                duration=str(duration),
+                image_list=image_list,
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProImageToVideoNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProImageToVideoNode",
+            display_name="Kling Omni Image to Video (Pro)",
+            category="api node/video/Kling",
+            description="Use up to 7 reference images to generate a video with the latest Kling model.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+                IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+                IO.Image.Input(
+                    "reference_images",
+                    tooltip="Up to 7 reference images.",
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        aspect_ratio: str,
+        duration: int,
+        reference_images: Input.Image,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        if get_number_of_images(reference_images) > 7:
+            raise ValueError("The maximum number of reference images is 7.")
+        for i in reference_images:
+            validate_image_dimensions(i, min_width=300, min_height=300)
+            validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+        image_list: list[OmniParamImage] = []
+        for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+            image_list.append(OmniParamImage(image_url=i))
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProReferences2VideoRequest(
+                model_name=model_name,
+                prompt=prompt,
+                aspect_ratio=aspect_ratio,
+                duration=str(duration),
+                image_list=image_list,
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProVideoToVideoNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProVideoToVideoNode",
+            display_name="Kling Omni Video to Video (Pro)",
+            category="api node/video/Kling",
+            description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+                IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+                IO.Video.Input("reference_video", tooltip="Video to use as a reference."),
+                IO.Boolean.Input("keep_original_sound", default=True),
+                IO.Image.Input(
+                    "reference_images",
+                    tooltip="Up to 4 additional reference images.",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        aspect_ratio: str,
+        duration: int,
+        reference_video: Input.Video,
+        keep_original_sound: bool,
+        reference_images: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
+        validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+        image_list: list[OmniParamImage] = []
+        if reference_images is not None:
+            if get_number_of_images(reference_images) > 4:
+                raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+            for i in reference_images:
+                validate_image_dimensions(i, min_width=300, min_height=300)
+                validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+            for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+                image_list.append(OmniParamImage(image_url=i))
+        video_list = [
+            OmniParamVideo(
+                video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"),
+                refer_type="feature",
+                keep_original_sound="yes" if keep_original_sound else "no",
+            )
+        ]
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProReferences2VideoRequest(
+                model_name=model_name,
+                prompt=prompt,
+                aspect_ratio=aspect_ratio,
+                duration=str(duration),
+                image_list=image_list if image_list else None,
+                video_list=video_list,
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProEditVideoNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingOmniProEditVideoNode",
+            display_name="Kling Omni Edit Video (Pro)",
+            category="api node/video/Kling",
+            description="Edit an existing video with the latest model from Kling.",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-video-o1"]),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video content. "
+                    "This can include both positive and negative descriptions.",
+                ),
+                IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."),
+                IO.Boolean.Input("keep_original_sound", default=True),
+                IO.Image.Input(
+                    "reference_images",
+                    tooltip="Up to 4 additional reference images.",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        video: Input.Video,
+        keep_original_sound: bool,
+        reference_images: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        prompt = normalize_omni_prompt_references(prompt)
+        validate_string(prompt, min_length=1, max_length=2500)
+        validate_video_duration(video, min_duration=3.0, max_duration=10.05)
+        validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+        image_list: list[OmniParamImage] = []
+        if reference_images is not None:
+            if get_number_of_images(reference_images) > 4:
+                raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+            for i in reference_images:
+                validate_image_dimensions(i, min_width=300, min_height=300)
+                validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+            for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+                image_list.append(OmniParamImage(image_url=i))
+        video_list = [
+            OmniParamVideo(
+                video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"),
+                refer_type="base",
+                keep_original_sound="yes" if keep_original_sound else "no",
+            )
+        ]
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+            response_model=OmniTaskStatusResponse,
+            data=OmniProReferences2VideoRequest(
+                model_name=model_name,
+                prompt=prompt,
+                aspect_ratio=None,
+                duration=None,
+                image_list=image_list if image_list else None,
+                video_list=video_list,
+            ),
+        )
+        return await finish_omni_video_task(cls, response)
+
+
+class OmniProImageNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="KlingOmniProImageNode",
|
||||||
|
display_name="Kling Omni Image (Pro)",
|
||||||
|
category="api node/image/Kling",
|
||||||
|
description="Create or edit images with the latest model from Kling.",
|
||||||
|
inputs=[
|
||||||
|
IO.Combo.Input("model_name", options=["kling-image-o1"]),
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="A text prompt describing the image content. "
|
||||||
|
"This can include both positive and negative descriptions.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1K", "2K"]),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"reference_images",
|
||||||
|
tooltip="Up to 10 additional reference images.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Image.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model_name: str,
|
||||||
|
prompt: str,
|
||||||
|
resolution: str,
|
||||||
|
aspect_ratio: str,
|
||||||
|
reference_images: Input.Image | None = None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
prompt = normalize_omni_prompt_references(prompt)
|
||||||
|
validate_string(prompt, min_length=1, max_length=2500)
|
||||||
|
image_list: list[OmniImageParamImage] = []
|
||||||
|
if reference_images is not None:
|
||||||
|
if get_number_of_images(reference_images) > 10:
|
||||||
|
raise ValueError("The maximum number of reference images is 10.")
|
||||||
|
for i in reference_images:
|
||||||
|
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||||
|
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||||
|
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||||
|
image_list.append(OmniImageParamImage(image=i))
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
|
||||||
|
response_model=OmniTaskStatusResponse,
|
||||||
|
data=OmniProImageRequest(
|
||||||
|
model_name=model_name,
|
||||||
|
prompt=prompt,
|
||||||
|
resolution=resolution.lower(),
|
||||||
|
aspect_ratio=aspect_ratio,
|
||||||
|
image_list=image_list if image_list else None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if response.code:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||||
|
)
|
||||||
|
final_response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
|
||||||
|
response_model=OmniTaskStatusResponse,
|
||||||
|
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
|
||||||
|
|
||||||
|
|
||||||
class KlingCameraControlT2VNode(IO.ComfyNode):
|
class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
|
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
|
||||||
@@ -1162,7 +1685,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
             category="api node/video/Kling",
             description="Achieve different special effects when generating a video based on the effect_scene.",
             inputs=[
-                IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"),
+                IO.Image.Input(
+                    "image",
+                    tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1",
+                ),
                 IO.Combo.Input(
                     "effect_scene",
                     options=[i.value for i in KlingSingleImageEffectsScene],
@@ -1525,6 +2051,12 @@ class KlingExtension(ComfyExtension):
             KlingImageGenerationNode,
             KlingSingleImageVideoEffectNode,
             KlingDualCharacterVideoEffectNode,
+            OmniProTextToVideoNode,
+            OmniProFirstLastFrameNode,
+            OmniProImageToVideoNode,
+            OmniProVideoToVideoNode,
+            OmniProEditVideoNode,
+            # OmniProImageNode,  # need support from backend
         ]
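
The hunk above wires the new Omni nodes into `KlingExtension`; a node class only shows up in the UI once the extension returns it. Below is a minimal sketch of that registration pattern, assuming the v3 extension API exposes an async `get_node_list` hook and a module-level `comfy_entrypoint` (both names are assumptions here, not shown in this diff):

```python
from typing_extensions import override

from comfy_api.latest import ComfyExtension, IO


class KlingOmniExtensionSketch(ComfyExtension):
    """Illustrative only: mirrors how KlingExtension exposes its nodes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        # Returning a class here is what publishes its node_id to the UI;
        # commenting one out (like OmniProImageNode above) hides it.
        return [
            OmniProVideoToVideoNode,  # defined earlier in this diff
            OmniProEditVideoNode,
        ]


async def comfy_entrypoint() -> KlingOmniExtensionSketch:
    return KlingOmniExtensionSketch()
```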
@@ -92,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode):
                 IO.Hidden.unique_id,
             ],
             is_api_node=True,
+            is_deprecated=True,
         )

     @classmethod
@@ -152,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode):
                 IO.Hidden.unique_id,
             ],
             is_api_node=True,
+            is_deprecated=True,
         )

     @classmethod
@@ -239,6 +241,7 @@ class PikaScenes(IO.ComfyNode):
                 IO.Hidden.unique_id,
             ],
             is_api_node=True,
+            is_deprecated=True,
         )

     @classmethod
@@ -323,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode):
                 IO.Hidden.unique_id,
             ],
             is_api_node=True,
+            is_deprecated=True,
         )

     @classmethod
@@ -399,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode):
                 IO.Hidden.unique_id,
             ],
             is_api_node=True,
+            is_deprecated=True,
         )

     @classmethod
@@ -466,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode):
                 IO.Hidden.unique_id,
             ],
             is_api_node=True,
+            is_deprecated=True,
         )

     @classmethod
@@ -515,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode):
                 IO.Hidden.unique_id,
             ],
             is_api_node=True,
+            is_deprecated=True,
         )

     @classmethod
@@ -47,6 +47,7 @@ from .validation_utils import (
     validate_string,
     validate_video_dimensions,
     validate_video_duration,
+    validate_video_frame_count,
 )

 __all__ = [
@@ -94,6 +95,7 @@ __all__ = [
     "validate_string",
     "validate_video_dimensions",
     "validate_video_duration",
+    "validate_video_frame_count",
     # Misc functions
     "get_fs_object_size",
 ]
@@ -2,8 +2,8 @@ import asyncio
 import contextlib
 import os
 import time
+from collections.abc import Callable
 from io import BytesIO
-from typing import Callable, Optional, Union

 from comfy.cli_args import args
 from comfy.model_management import processing_interrupted
@@ -35,12 +35,12 @@ def default_base_url() -> str:

 async def sleep_with_interrupt(
     seconds: float,
-    node_cls: Optional[type[IO.ComfyNode]],
-    label: Optional[str] = None,
-    start_ts: Optional[float] = None,
-    estimated_total: Optional[int] = None,
+    node_cls: type[IO.ComfyNode] | None,
+    label: str | None = None,
+    start_ts: float | None = None,
+    estimated_total: int | None = None,
     *,
-    display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
+    display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
 ):
     """
     Sleep in 1s slices while:
@@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
     return mime_type.split("/")[-1].lower()


-def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
+def get_fs_object_size(path_or_object: str | BytesIO) -> int:
     if isinstance(path_or_object, str):
         return os.path.getsize(path_or_object)
     return len(path_or_object.getvalue())
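
The hunks above (and most of the ones that follow) are one mechanical change: `typing.Optional[X]`/`typing.Union[X, Y]` become PEP 604 `X | None`/`X | Y`, and `Callable`/`Iterable` move to `collections.abc`. The two spellings produce equivalent annotations on Python 3.10+, which is why the change is runtime- and type-checker-neutral; a small self-contained demonstration:

```python
from collections.abc import Callable
from typing import Optional, Union, get_args

# PEP 604 unions carry the same arguments as their typing.* equivalents.
assert get_args(int | None) == get_args(Optional[int]) == (int, type(None))
assert get_args(str | int) == get_args(Union[str, int])

# collections.abc.Callable has been subscriptable in annotations since 3.9.
Extractor = Callable[[dict], float | None]


def demo(price_extractor: Extractor | None = None) -> float | None:
    return price_extractor({"price": 1.25}) if price_extractor else None


print(demo(lambda d: d.get("price")))  # 1.25
```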
@@ -4,10 +4,11 @@ import json
 import logging
 import time
 import uuid
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from enum import Enum
 from io import BytesIO
-from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import Any, Literal, TypeVar
 from urllib.parse import urljoin, urlparse

 import aiohttp
@@ -37,8 +38,8 @@ class ApiEndpoint:
         path: str,
         method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
         *,
-        query_params: Optional[dict[str, Any]] = None,
-        headers: Optional[dict[str, str]] = None,
+        query_params: dict[str, Any] | None = None,
+        headers: dict[str, str] | None = None,
     ):
         self.path = path
         self.method = method
@@ -52,18 +53,18 @@ class _RequestConfig:
     endpoint: ApiEndpoint
     timeout: float
     content_type: str
-    data: Optional[dict[str, Any]]
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
-    multipart_parser: Optional[Callable]
+    data: dict[str, Any] | None
+    files: dict[str, Any] | list[tuple[str, Any]] | None
+    multipart_parser: Callable | None
     max_retries: int
     retry_delay: float
     retry_backoff: float
     wait_label: str = "Waiting"
     monitor_progress: bool = True
-    estimated_total: Optional[int] = None
-    final_label_on_success: Optional[str] = "Completed"
-    progress_origin_ts: Optional[float] = None
-    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
+    estimated_total: int | None = None
+    final_label_on_success: str | None = "Completed"
+    progress_origin_ts: float | None = None
+    price_extractor: Callable[[dict[str, Any]], float | None] | None = None


 @dataclass
@@ -71,10 +72,10 @@ class _PollUIState:
     started: float
     status_label: str = "Queued"
     is_queued: bool = True
-    price: Optional[float] = None
-    estimated_duration: Optional[int] = None
+    price: float | None = None
+    estimated_duration: int | None = None
     base_processing_elapsed: float = 0.0  # sum of completed active intervals
-    active_since: Optional[float] = None  # start time of current active interval (None if queued)
+    active_since: float | None = None  # start time of current active interval (None if queued)


 _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@@ -87,20 +88,20 @@ async def sync_op(
     cls: type[IO.ComfyNode],
     endpoint: ApiEndpoint,
     *,
-    response_model: Type[M],
-    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
-    data: Optional[BaseModel] = None,
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+    response_model: type[M],
+    price_extractor: Callable[[M | Any], float | None] | None = None,
+    data: BaseModel | None = None,
+    files: dict[str, Any] | list[tuple[str, Any]] | None = None,
     content_type: str = "application/json",
     timeout: float = 3600.0,
-    multipart_parser: Optional[Callable] = None,
+    multipart_parser: Callable | None = None,
     max_retries: int = 3,
     retry_delay: float = 1.0,
     retry_backoff: float = 2.0,
     wait_label: str = "Waiting for server",
-    estimated_duration: Optional[int] = None,
-    final_label_on_success: Optional[str] = "Completed",
-    progress_origin_ts: Optional[float] = None,
+    estimated_duration: int | None = None,
+    final_label_on_success: str | None = "Completed",
+    progress_origin_ts: float | None = None,
     monitor_progress: bool = True,
 ) -> M:
     raw = await sync_op_raw(
@@ -131,22 +132,22 @@ async def poll_op(
     cls: type[IO.ComfyNode],
     poll_endpoint: ApiEndpoint,
     *,
-    response_model: Type[M],
-    status_extractor: Callable[[M], Optional[Union[str, int]]],
-    progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
-    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
-    completed_statuses: Optional[list[Union[str, int]]] = None,
-    failed_statuses: Optional[list[Union[str, int]]] = None,
-    queued_statuses: Optional[list[Union[str, int]]] = None,
-    data: Optional[BaseModel] = None,
+    response_model: type[M],
+    status_extractor: Callable[[M | Any], str | int | None],
+    progress_extractor: Callable[[M | Any], int | None] | None = None,
+    price_extractor: Callable[[M | Any], float | None] | None = None,
+    completed_statuses: list[str | int] | None = None,
+    failed_statuses: list[str | int] | None = None,
+    queued_statuses: list[str | int] | None = None,
+    data: BaseModel | None = None,
     poll_interval: float = 5.0,
     max_poll_attempts: int = 120,
     timeout_per_poll: float = 120.0,
     max_retries_per_poll: int = 3,
     retry_delay_per_poll: float = 1.0,
     retry_backoff_per_poll: float = 2.0,
-    estimated_duration: Optional[int] = None,
-    cancel_endpoint: Optional[ApiEndpoint] = None,
+    estimated_duration: int | None = None,
+    cancel_endpoint: ApiEndpoint | None = None,
     cancel_timeout: float = 10.0,
 ) -> M:
     raw = await poll_op_raw(
@@ -178,22 +179,22 @@ async def sync_op_raw(
     cls: type[IO.ComfyNode],
     endpoint: ApiEndpoint,
     *,
-    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
-    data: Optional[Union[dict[str, Any], BaseModel]] = None,
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+    price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+    data: dict[str, Any] | BaseModel | None = None,
+    files: dict[str, Any] | list[tuple[str, Any]] | None = None,
     content_type: str = "application/json",
     timeout: float = 3600.0,
-    multipart_parser: Optional[Callable] = None,
+    multipart_parser: Callable | None = None,
     max_retries: int = 3,
     retry_delay: float = 1.0,
     retry_backoff: float = 2.0,
     wait_label: str = "Waiting for server",
-    estimated_duration: Optional[int] = None,
+    estimated_duration: int | None = None,
     as_binary: bool = False,
-    final_label_on_success: Optional[str] = "Completed",
-    progress_origin_ts: Optional[float] = None,
+    final_label_on_success: str | None = "Completed",
+    progress_origin_ts: float | None = None,
     monitor_progress: bool = True,
-) -> Union[dict[str, Any], bytes]:
+) -> dict[str, Any] | bytes:
     """
     Make a single network request.
     - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
@@ -229,21 +230,21 @@ async def poll_op_raw(
     cls: type[IO.ComfyNode],
     poll_endpoint: ApiEndpoint,
     *,
-    status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
-    progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
-    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
-    completed_statuses: Optional[list[Union[str, int]]] = None,
-    failed_statuses: Optional[list[Union[str, int]]] = None,
-    queued_statuses: Optional[list[Union[str, int]]] = None,
-    data: Optional[Union[dict[str, Any], BaseModel]] = None,
+    status_extractor: Callable[[dict[str, Any]], str | int | None],
+    progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
+    price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+    completed_statuses: list[str | int] | None = None,
+    failed_statuses: list[str | int] | None = None,
+    queued_statuses: list[str | int] | None = None,
+    data: dict[str, Any] | BaseModel | None = None,
     poll_interval: float = 5.0,
     max_poll_attempts: int = 120,
     timeout_per_poll: float = 120.0,
     max_retries_per_poll: int = 3,
     retry_delay_per_poll: float = 1.0,
     retry_backoff_per_poll: float = 2.0,
-    estimated_duration: Optional[int] = None,
-    cancel_endpoint: Optional[ApiEndpoint] = None,
+    estimated_duration: int | None = None,
+    cancel_endpoint: ApiEndpoint | None = None,
     cancel_timeout: float = 10.0,
 ) -> dict[str, Any]:
     """
@@ -261,7 +262,7 @@ async def poll_op_raw(
     consumed_attempts = 0  # counts only non-queued polls

     progress_bar = utils.ProgressBar(100) if progress_extractor else None
-    last_progress: Optional[int] = None
+    last_progress: int | None = None

     state = _PollUIState(started=started, estimated_duration=estimated_duration)
     stop_ticker = asyncio.Event()
@@ -420,10 +421,10 @@ async def poll_op_raw(

 def _display_text(
     node_cls: type[IO.ComfyNode],
-    text: Optional[str],
+    text: str | None,
     *,
-    status: Optional[Union[str, int]] = None,
-    price: Optional[float] = None,
+    status: str | int | None = None,
+    price: float | None = None,
 ) -> None:
     display_lines: list[str] = []
     if status:
@@ -440,13 +441,13 @@ def _display_text(

 def _display_time_progress(
     node_cls: type[IO.ComfyNode],
-    status: Optional[Union[str, int]],
+    status: str | int | None,
     elapsed_seconds: int,
-    estimated_total: Optional[int] = None,
+    estimated_total: int | None = None,
     *,
-    price: Optional[float] = None,
-    is_queued: Optional[bool] = None,
-    processing_elapsed_seconds: Optional[int] = None,
+    price: float | None = None,
+    is_queued: bool | None = None,
+    processing_elapsed_seconds: int | None = None,
 ) -> None:
     if estimated_total is not None and estimated_total > 0 and is_queued is False:
         pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
         raise ValueError("files tuple must be (filename, file[, content_type])")


-def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
+def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
     params = dict(endpoint_params or {})
     if method.upper() == "GET" and data:
         for k, v in data.items():
@@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
 def _snapshot_request_body_for_logging(
     content_type: str,
     method: str,
-    data: Optional[dict[str, Any]],
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
-) -> Optional[Union[dict[str, Any], str]]:
+    data: dict[str, Any] | None,
+    files: dict[str, Any] | list[tuple[str, Any]] | None,
+) -> dict[str, Any] | str | None:
     if method.upper() == "GET":
         return None
     if content_type == "multipart/form-data":
@@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
     attempt = 0
     delay = cfg.retry_delay
     operation_succeeded: bool = False
-    final_elapsed_seconds: Optional[int] = None
-    extracted_price: Optional[float] = None
+    final_elapsed_seconds: int | None = None
+    extracted_price: float | None = None
     while True:
         attempt += 1
         stop_event = asyncio.Event()
-        monitor_task: Optional[asyncio.Task] = None
-        sess: Optional[aiohttp.ClientSession] = None
+        monitor_task: asyncio.Task | None = None
+        sess: aiohttp.ClientSession | None = None

         operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
         logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
     )


-def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
+def _validate_or_raise(response_model: type[M], payload: Any) -> M:
     try:
         return response_model.model_validate(payload)
     except Exception as e:
@@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:


 def _wrap_model_extractor(
-    response_model: Type[M],
-    extractor: Optional[Callable[[M], Any]],
-) -> Optional[Callable[[dict[str, Any]], Any]]:
+    response_model: type[M],
+    extractor: Callable[[M], Any] | None,
+) -> Callable[[dict[str, Any]], Any] | None:
     """Wrap a typed extractor so it can be used by the dict-based poller.
     Validates the dict into `response_model` before invoking `extractor`.
     Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@@ -929,10 +930,10 @@ def _wrap_model_extractor(
     return _wrapped


-def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
+def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
     if not values:
         return set()
-    out: set[Union[str, int]] = set()
+    out: set[str | int] = set()
     for v in values:
         nv = _normalize_status_value(v)
         if nv is not None:
@@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
     return out


-def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
+def _normalize_status_value(val: str | int | None) -> str | int | None:
     if isinstance(val, str):
         return val.strip().lower()
     return val
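
The `_wrap_model_extractor` hunk keeps its docstring: the poller traffics in raw dicts while callers supply extractors typed against a pydantic model, so the wrapper validates the dict once and caches by `id(dict)`. A minimal sketch of that adapter, assuming pydantic v2's `model_validate`; the cache details beyond what the docstring states are guesses:

```python
from collections.abc import Callable
from typing import Any, TypeVar

from pydantic import BaseModel

M = TypeVar("M", bound=BaseModel)


def wrap_model_extractor(
    response_model: type[M],
    extractor: Callable[[M], Any] | None,
) -> Callable[[dict[str, Any]], Any] | None:
    """Adapt a typed extractor to a dict-based poller (illustrative sketch)."""
    if extractor is None:
        return None

    cache: dict[int, M] = {}  # keyed by id(dict): one poll payload validates once

    def _wrapped(payload: dict[str, Any]) -> Any:
        key = id(payload)
        model = cache.get(key)
        if model is None:
            model = response_model.model_validate(payload)  # pydantic v2
            cache.clear()  # keep only the most recent payload's model
            cache[key] = model
        return extractor(model)

    return _wrapped


class Status(BaseModel):
    state: str


extract = wrap_model_extractor(Status, lambda s: s.state)
print(extract({"state": "succeed"}))  # succeed
```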
@@ -4,7 +4,6 @@ import math
 import mimetypes
 import uuid
 from io import BytesIO
-from typing import Optional

 import av
 import numpy as np
@@ -12,8 +11,7 @@ import torch
 from PIL import Image

 from comfy.utils import common_upscale
-from comfy_api.latest import Input, InputImpl
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import Input, InputImpl, Types

 from ._helpers import mimetype_to_extension

@@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:

 def tensor_to_bytesio(
     image: torch.Tensor,
-    name: Optional[str] = None,
+    name: str | None = None,
     total_pixels: int = 2048 * 2048,
     mime_type: str = "image/png",
 ) -> BytesIO:
@@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co

 def video_to_base64_string(
     video: Input.Video,
-    container_format: VideoContainer = None,
-    codec: VideoCodec = None
+    container_format: Types.VideoContainer | None = None,
+    codec: Types.VideoCodec | None = None,
 ) -> str:
     """
     Converts a video input to a base64 string.
@@ -189,12 +187,11 @@ def video_to_base64_string(
         codec: Optional codec to use (defaults to video.codec if available)
     """
     video_bytes_io = BytesIO()
-    # Use provided format/codec if specified, otherwise use video's own if available
-    format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
-    codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
-    video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
+    video.save_to(
+        video_bytes_io,
+        format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
+        codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
+    )
     video_bytes_io.seek(0)
     return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
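
The `video_to_base64_string` rewrite folds the format/codec fallback into the `save_to` call. Note the switch from `x if x is not None else ...` to `x or ...`: a falsy-but-not-None argument would now also fall back, which presumably cannot occur with enum members. A runnable toy model of the fallback rules (`FakeVideo` is hypothetical, standing in for an `Input.Video`):

```python
from io import BytesIO


class FakeVideo:
    """Hypothetical stand-in for an Input.Video, only to show the fallback."""

    container = "mp4"   # what getattr(video, "container", ...) would find
    codec = "h264"

    def save_to(self, dest: BytesIO, format=None, codec=None):
        dest.write(f"{format}/{codec}".encode())


video = FakeVideo()
buf = BytesIO()
container_format = None  # caller passed nothing -> fall back to the video's own
codec = "av1"            # caller override wins

# Same shape as the new call in video_to_base64_string:
video.save_to(
    buf,
    format=container_format or getattr(video, "container", "mp4"),
    codec=codec or getattr(video, "codec", "h264"),
)
print(buf.getvalue())  # b'mp4/av1'
```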
@@ -3,15 +3,15 @@ import contextlib
 import uuid
 from io import BytesIO
 from pathlib import Path
-from typing import IO, Optional, Union
+from typing import IO
 from urllib.parse import urljoin, urlparse

 import aiohttp
 import torch
 from aiohttp.client_exceptions import ClientError, ContentTypeError

-from comfy_api.input_impl import VideoFromFile
 from comfy_api.latest import IO as COMFY_IO
+from comfy_api.latest import InputImpl

 from . import request_logger
 from ._helpers import (
@@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}

 async def download_url_to_bytesio(
     url: str,
-    dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
+    dest: BytesIO | IO[bytes] | str | Path | None,
     *,
-    timeout: Optional[float] = None,
+    timeout: float | None = None,
     max_retries: int = 5,
     retry_delay: float = 1.0,
     retry_backoff: float = 2.0,
@@ -71,10 +71,10 @@ async def download_url_to_bytesio(

     is_path_sink = isinstance(dest, (str, Path))
     fhandle = None
-    session: Optional[aiohttp.ClientSession] = None
-    stop_evt: Optional[asyncio.Event] = None
-    monitor_task: Optional[asyncio.Task] = None
-    req_task: Optional[asyncio.Task] = None
+    session: aiohttp.ClientSession | None = None
+    stop_evt: asyncio.Event | None = None
+    monitor_task: asyncio.Task | None = None
+    req_task: asyncio.Task | None = None

     try:
         with contextlib.suppress(Exception):
@@ -234,11 +234,11 @@ async def download_url_to_video_output(
     timeout: float = None,
     max_retries: int = 5,
     cls: type[COMFY_IO.ComfyNode] = None,
-) -> VideoFromFile:
+) -> InputImpl.VideoFromFile:
     """Downloads a video from a URL and returns a `VIDEO` output."""
     result = BytesIO()
     await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
-    return VideoFromFile(result)
+    return InputImpl.VideoFromFile(result)


 async def download_url_as_bytesio(
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import datetime
 import hashlib
 import json
@@ -4,15 +4,13 @@ import logging
 import time
 import uuid
 from io import BytesIO
-from typing import Optional
 from urllib.parse import urlparse

 import aiohttp
 import torch
 from pydantic import BaseModel, Field

-from comfy_api.latest import IO, Input
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import IO, Input, Types

 from . import request_logger
 from ._helpers import is_processing_interrupted, sleep_with_interrupt
@@ -32,7 +30,7 @@ from .conversions import (

 class UploadRequest(BaseModel):
     file_name: str = Field(..., description="Filename to upload")
-    content_type: Optional[str] = Field(
+    content_type: str | None = Field(
         None,
         description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
     )
@@ -56,7 +54,7 @@ async def upload_images_to_comfyapi(
     Uploads images to ComfyUI API and returns download URLs.
     To upload multiple images, stack them in the batch dimension first.
     """
-    # if batch, try to upload each file if max_images is greater than 0
+    # if batched, try to upload each file if max_images is greater than 0
     download_urls: list[str] = []
     is_batch = len(image.shape) > 3
     batch_len = image.shape[0] if is_batch else 1
@@ -100,9 +98,10 @@ async def upload_video_to_comfyapi(
     cls: type[IO.ComfyNode],
     video: Input.Video,
     *,
-    container: VideoContainer = VideoContainer.MP4,
-    codec: VideoCodec = VideoCodec.H264,
-    max_duration: Optional[int] = None,
+    container: Types.VideoContainer = Types.VideoContainer.MP4,
+    codec: Types.VideoCodec = Types.VideoCodec.H264,
+    max_duration: int | None = None,
+    wait_label: str | None = "Uploading",
 ) -> str:
     """
     Uploads a single video to ComfyUI API and returns its download URL.
@@ -127,7 +126,7 @@ async def upload_video_to_comfyapi(
     video.save_to(video_bytes_io, format=container, codec=codec)
     video_bytes_io.seek(0)

-    return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
+    return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)


 async def upload_file_to_comfyapi(
@@ -219,7 +218,7 @@ async def upload_file(
         return

     monitor_task = asyncio.create_task(_monitor())
-    sess: Optional[aiohttp.ClientSession] = None
+    sess: aiohttp.ClientSession | None = None
     try:
         try:
             request_logger.log_request_response(
@@ -1,9 +1,7 @@
 import logging
-from typing import Optional

 import torch

-from comfy_api.input.video_types import VideoInput
 from comfy_api.latest import Input


@@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:

 def validate_image_dimensions(
     image: torch.Tensor,
-    min_width: Optional[int] = None,
-    max_width: Optional[int] = None,
-    min_height: Optional[int] = None,
-    max_height: Optional[int] = None,
+    min_width: int | None = None,
+    max_width: int | None = None,
+    min_height: int | None = None,
+    max_height: int | None = None,
 ):
     height, width = get_image_dimensions(image)

@@ -37,8 +35,8 @@ def validate_image_dimensions(

 def validate_image_aspect_ratio(
     image: torch.Tensor,
-    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
-    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
+    min_ratio: tuple[float, float] | None = None,  # e.g. (1, 4)
+    max_ratio: tuple[float, float] | None = None,  # e.g. (4, 1)
     *,
     strict: bool = True,  # True -> (min, max); False -> [min, max]
 ) -> float:
@@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(

 def validate_aspect_ratio_string(
     aspect_ratio: str,
-    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
-    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
+    min_ratio: tuple[float, float] | None = None,  # e.g. (1, 4)
+    max_ratio: tuple[float, float] | None = None,  # e.g. (4, 1)
     *,
     strict: bool = False,  # True -> (min, max); False -> [min, max]
 ) -> float:
@@ -97,10 +95,10 @@ def validate_aspect_ratio_string(

 def validate_video_dimensions(
     video: Input.Video,
-    min_width: Optional[int] = None,
-    max_width: Optional[int] = None,
-    min_height: Optional[int] = None,
-    max_height: Optional[int] = None,
+    min_width: int | None = None,
+    max_width: int | None = None,
+    min_height: int | None = None,
+    max_height: int | None = None,
 ):
     try:
         width, height = video.get_dimensions()
@@ -120,8 +118,8 @@ def validate_video_dimensions(

 def validate_video_duration(
     video: Input.Video,
-    min_duration: Optional[float] = None,
-    max_duration: Optional[float] = None,
+    min_duration: float | None = None,
+    max_duration: float | None = None,
 ):
     try:
         duration = video.get_duration()
@@ -136,6 +134,23 @@ def validate_video_duration(
         raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")


+def validate_video_frame_count(
+    video: Input.Video,
+    min_frame_count: int | None = None,
+    max_frame_count: int | None = None,
+):
+    try:
+        frame_count = video.get_frame_count()
+    except Exception as e:
+        logging.error("Error getting frame count of video: %s", e)
+        return
+
+    if min_frame_count is not None and min_frame_count > frame_count:
+        raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
+    if max_frame_count is not None and frame_count > max_frame_count:
+        raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
+
+
 def get_number_of_images(images):
     if isinstance(images, torch.Tensor):
         return images.shape[0] if images.ndim >= 4 else 1
@@ -144,8 +159,8 @@ def get_number_of_images(images):

 def validate_audio_duration(
     audio: Input.Audio,
-    min_duration: Optional[float] = None,
-    max_duration: Optional[float] = None,
+    min_duration: float | None = None,
+    max_duration: float | None = None,
 ) -> None:
     sr = int(audio["sample_rate"])
     dur = int(audio["waveform"].shape[-1]) / sr
@@ -177,7 +192,7 @@ def validate_string(
     )


-def validate_container_format_is_mp4(video: VideoInput) -> None:
+def validate_container_format_is_mp4(video: Input.Video) -> None:
     """Validates video container format is MP4."""
     container_format = video.get_container_format()
     if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
@@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:

 def _assert_ratio_bounds(
     ar: float,
     *,
-    min_ratio: Optional[tuple[float, float]] = None,
-    max_ratio: Optional[tuple[float, float]] = None,
+    min_ratio: tuple[float, float] | None = None,
+    max_ratio: tuple[float, float] | None = None,
     strict: bool = True,
 ) -> None:
     """Validate a numeric aspect ratio against optional min/max ratio bounds."""
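
With `validate_video_frame_count` added above (and exported earlier in this diff), an API node can bound frame count alongside duration and resolution. A hedged sketch of such a pre-flight check; the numeric limits are made up for illustration, and `video` stands for whatever an `Input.Video` socket provides:

```python
def check_reference_video(video) -> None:
    # Same validators the Kling Omni nodes use, plus the new frame-count bound.
    # Each raises ValueError on bad input; unreadable properties are logged and skipped.
    validate_video_duration(video, min_duration=3.0, max_duration=10.05)
    validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
    validate_video_frame_count(video, min_frame_count=1, max_frame_count=2000)  # hypothetical limits
```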
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from comfy_api.latest import IO


 def validate_node_input(
@@ -23,6 +24,11 @@ def validate_node_input(
     if not received_type != input_type:
         return True

+    # If the received type or input_type is a MatchType, we can return True immediately;
+    # validation for this is handled by the frontend
+    if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
+        return True
+
     # Not equal, and not strings
     if not isinstance(received_type, str) or not isinstance(input_type, str):
         return False
@ -6,65 +6,80 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import os
|
import os
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import random
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import logging
|
import logging
|
||||||
from comfy.cli_args import args
|
from typing_extensions import override
|
||||||
from comfy.comfy_types import FileLocator
|
from comfy_api.latest import ComfyExtension, IO, UI
|
||||||
|
|
||||||
class EmptyLatentAudio:
|
class EmptyLatentAudio(IO.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.device = comfy.model_management.intermediate_device()
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="EmptyLatentAudio",
|
||||||
|
display_name="Empty Latent Audio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||||
|
IO.Int.Input(
|
||||||
|
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Latent.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
||||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
|
||||||
}}
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
|
||||||
|
|
||||||
def generate(self, seconds, batch_size):
|
|
||||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
|
||||||
return ({"samples":latent, "type": "audio"}, )
|
return IO.NodeOutput({"samples":latent, "type": "audio"})
|
||||||
|
|
||||||
class ConditioningStableAudio:
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningStableAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING", ),
|
return IO.Schema(
|
||||||
"negative": ("CONDITIONING", ),
|
node_id="ConditioningStableAudio",
|
||||||
"seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
|
category="conditioning",
|
||||||
"seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
|
inputs=[
|
||||||
}}
|
IO.Conditioning.Input("positive"),
|
||||||
|
IO.Conditioning.Input("negative"),
|
||||||
|
IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
|
||||||
|
IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Conditioning.Output(display_name="positive"),
|
||||||
|
IO.Conditioning.Output(display_name="negative"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
@classmethod
|
||||||
RETURN_NAMES = ("positive", "negative")
|
def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
|
||||||
|
|
||||||
FUNCTION = "append"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning"
|
|
||||||
|
|
||||||
def append(self, positive, negative, seconds_start, seconds_total):
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
||||||
return (positive, negative)
|
return IO.NodeOutput(positive, negative)
|
||||||
|
|
||||||
class VAEEncodeAudio:
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class VAEEncodeAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
|
return IO.Schema(
|
||||||
RETURN_TYPES = ("LATENT",)
|
node_id="VAEEncodeAudio",
|
||||||
FUNCTION = "encode"
|
display_name="VAE Encode Audio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Audio.Input("audio"),
|
||||||
|
IO.Vae.Input("vae"),
|
||||||
|
],
|
||||||
|
outputs=[IO.Latent.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
@classmethod
|
||||||
|
def execute(cls, vae, audio) -> IO.NodeOutput:
|
||||||
def encode(self, vae, audio):
|
|
||||||
sample_rate = audio["sample_rate"]
|
sample_rate = audio["sample_rate"]
|
||||||
if 44100 != sample_rate:
|
if 44100 != sample_rate:
|
||||||
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
|
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
|
||||||
@ -72,213 +87,134 @@ class VAEEncodeAudio:
|
|||||||
waveform = audio["waveform"]
|
waveform = audio["waveform"]
|
||||||
|
|
||||||
t = vae.encode(waveform.movedim(1, -1))
|
t = vae.encode(waveform.movedim(1, -1))
|
||||||
return ({"samples":t}, )
|
return IO.NodeOutput({"samples":t})
|
||||||
|
|
||||||
class VAEDecodeAudio:
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class VAEDecodeAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
return IO.Schema(
|
||||||
RETURN_TYPES = ("AUDIO",)
|
node_id="VAEDecodeAudio",
|
||||||
FUNCTION = "decode"
|
display_name="VAE Decode Audio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Latent.Input("samples"),
|
||||||
|
IO.Vae.Input("vae"),
|
||||||
|
],
|
||||||
|
outputs=[IO.Audio.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
@classmethod
|
||||||
|
def execute(cls, vae, samples) -> IO.NodeOutput:
|
||||||
def decode(self, vae, samples):
|
|
||||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||||
std[std < 1.0] = 1.0
|
std[std < 1.0] = 1.0
|
||||||
audio /= std
|
audio /= std
|
||||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
|
||||||
|
|
||||||
|
decode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
-def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
-    filename_prefix += self.prefix_append
-    full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
-    results: list[FileLocator] = []
-
-    # Prepare metadata dictionary
-    metadata = {}
-    if not args.disable_metadata:
-        if prompt is not None:
-            metadata["prompt"] = json.dumps(prompt)
-        if extra_pnginfo is not None:
-            for x in extra_pnginfo:
-                metadata[x] = json.dumps(extra_pnginfo[x])
-
-    # Opus supported sample rates
-    OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
-
-    for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
-        filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
-        file = f"{filename_with_batch_num}_{counter:05}_.{format}"
-        output_path = os.path.join(full_output_folder, file)
-
-        # Use original sample rate initially
-        sample_rate = audio["sample_rate"]
-
-        # Handle Opus sample rate requirements
-        if format == "opus":
-            if sample_rate > 48000:
-                sample_rate = 48000
-            elif sample_rate not in OPUS_RATES:
-                # Find the next highest supported rate
-                for rate in sorted(OPUS_RATES):
-                    if rate > sample_rate:
-                        sample_rate = rate
-                        break
-                if sample_rate not in OPUS_RATES:  # Fallback if still not supported
-                    sample_rate = 48000
-
-        # Resample if necessary
-        if sample_rate != audio["sample_rate"]:
-            waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
-
-        # Create output with specified format
-        output_buffer = io.BytesIO()
-        output_container = av.open(output_buffer, mode='w', format=format)
-
-        # Set metadata on the container
-        for key, value in metadata.items():
-            output_container.metadata[key] = value
-
-        layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
-        # Set up the output stream with appropriate properties
-        if format == "opus":
-            out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
-            if quality == "64k":
-                out_stream.bit_rate = 64000
-            elif quality == "96k":
-                out_stream.bit_rate = 96000
-            elif quality == "128k":
-                out_stream.bit_rate = 128000
-            elif quality == "192k":
-                out_stream.bit_rate = 192000
-            elif quality == "320k":
-                out_stream.bit_rate = 320000
-        elif format == "mp3":
-            out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
-            if quality == "V0":
-                #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
-                out_stream.codec_context.qscale = 1
-            elif quality == "128k":
-                out_stream.bit_rate = 128000
-            elif quality == "320k":
-                out_stream.bit_rate = 320000
-        else: #format == "flac":
-            out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
-
-        frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
-        frame.sample_rate = sample_rate
-        frame.pts = 0
-        output_container.mux(out_stream.encode(frame))
-
-        # Flush encoder
-        output_container.mux(out_stream.encode(None))
-
-        # Close containers
-        output_container.close()
-
-        # Write the output to file
-        output_buffer.seek(0)
-        with open(output_path, 'wb') as f:
-            f.write(output_buffer.getbuffer())
-
-        results.append({
-            "filename": file,
-            "subfolder": subfolder,
-            "type": self.type
-        })
-        counter += 1
-
-    return { "ui": { "audio": results } }
+class SaveAudio(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="SaveAudio",
+            display_name="Save Audio (FLAC)",
+            category="audio",
+            inputs=[
+                IO.Audio.Input("audio"),
+                IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+            ],
+            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+            is_output_node=True,
+        )
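The deleted module-level `save_audio` helper is replaced by `UI.AudioSaveHelper.get_save_audio_ui` in the classes below. One piece of it worth keeping in mind is the Opus sample-rate rule (Opus only accepts a fixed set of rates): pick the next supported rate above the source, else fall back to 48 kHz. As a standalone sketch, lifted directly from the deleted code:

```python
# Opus sample-rate selection from the deleted save_audio helper.
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]

def pick_opus_rate(sample_rate: int) -> int:
    if sample_rate > 48000:
        return 48000
    if sample_rate in OPUS_RATES:
        return sample_rate
    # Find the next highest supported rate, falling back to 48 kHz.
    for rate in sorted(OPUS_RATES):
        if rate > sample_rate:
            return rate
    return 48000

assert pick_opus_rate(44100) == 48000  # 44.1 kHz is not an Opus rate
assert pick_opus_rate(16000) == 16000
```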
-class SaveAudio:
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
-        self.type = "output"
-        self.prefix_append = ""
-
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "audio": ("AUDIO", ),
-                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
-                              },
-                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
-                }
-
-    RETURN_TYPES = ()
-    FUNCTION = "save_flac"
-    OUTPUT_NODE = True
-    CATEGORY = "audio"
-
-    def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
-        return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
+    def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
+        return IO.NodeOutput(
+            ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
+        )
+
+    save_flac = execute # TODO: remove
 
-class SaveAudioMP3:
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
-        self.type = "output"
-        self.prefix_append = ""
+
+class SaveAudioMP3(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="SaveAudioMP3",
+            display_name="Save Audio (MP3)",
+            category="audio",
+            inputs=[
+                IO.Audio.Input("audio"),
+                IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+                IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
+            ],
+            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+            is_output_node=True,
+        )
 
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "audio": ("AUDIO", ),
-                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
-                              "quality": (["V0", "128k", "320k"], {"default": "V0"}),
-                              },
-                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
-                }
-
-    RETURN_TYPES = ()
-    FUNCTION = "save_mp3"
-    OUTPUT_NODE = True
-    CATEGORY = "audio"
-
-    def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
-        return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
+    def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
+        return IO.NodeOutput(
+            ui=UI.AudioSaveHelper.get_save_audio_ui(
+                audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
+            )
+        )
+
+    save_mp3 = execute # TODO: remove
 
-class SaveAudioOpus:
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
-        self.type = "output"
-        self.prefix_append = ""
+
+class SaveAudioOpus(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="SaveAudioOpus",
+            display_name="Save Audio (Opus)",
+            category="audio",
+            inputs=[
+                IO.Audio.Input("audio"),
+                IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+                IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
+            ],
+            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+            is_output_node=True,
+        )
 
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "audio": ("AUDIO", ),
-                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
-                              "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
-                              },
-                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
-                }
-
-    RETURN_TYPES = ()
-    FUNCTION = "save_opus"
-    OUTPUT_NODE = True
-    CATEGORY = "audio"
-
-    def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
-        return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
+    def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
+        return IO.NodeOutput(
+            ui=UI.AudioSaveHelper.get_save_audio_ui(
+                audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
+            )
+        )
+
+    save_opus = execute # TODO: remove
 
-class PreviewAudio(SaveAudio):
-    def __init__(self):
-        self.output_dir = folder_paths.get_temp_directory()
-        self.type = "temp"
-        self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
+
+class PreviewAudio(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="PreviewAudio",
+            display_name="Preview Audio",
+            category="audio",
+            inputs=[
+                IO.Audio.Input("audio"),
+            ],
+            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+            is_output_node=True,
+        )
 
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required":
-                    {"audio": ("AUDIO", ), },
-                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
-                }
+    def execute(cls, audio) -> IO.NodeOutput:
+        return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
+
+    save_flac = execute # TODO: remove
 
 def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
     """Convert audio to float 32 bits PCM format."""
@@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
     wav = f32_pcm(wav)
     return wav, sr
 
-class LoadAudio:
+class LoadAudio(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
+    def define_schema(cls):
         input_dir = folder_paths.get_input_directory()
         files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
-        return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
+        return IO.Schema(
+            node_id="LoadAudio",
+            display_name="Load Audio",
+            category="audio",
+            inputs=[
+                IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
+            ],
+            outputs=[IO.Audio.Output()],
+        )
 
-    CATEGORY = "audio"
-
-    RETURN_TYPES = ("AUDIO", )
-    FUNCTION = "load"
-
-    def load(self, audio):
+    @classmethod
+    def execute(cls, audio) -> IO.NodeOutput:
         audio_path = folder_paths.get_annotated_filepath(audio)
         waveform, sample_rate = load(audio_path)
         audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
-        return (audio, )
+        return IO.NodeOutput(audio)
 
     @classmethod
-    def IS_CHANGED(s, audio):
+    def fingerprint_inputs(cls, audio):
         image_path = folder_paths.get_annotated_filepath(audio)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
@@ -343,46 +283,69 @@ class LoadAudio:
         return m.digest().hex()
 
     @classmethod
-    def VALIDATE_INPUTS(s, audio):
+    def validate_inputs(cls, audio):
         if not folder_paths.exists_annotated_filepath(audio):
             return "Invalid audio file: {}".format(audio)
         return True
 
-class RecordAudio:
+    load = execute # TODO: remove
+
+
+class RecordAudio(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"audio": ("AUDIO_RECORD", {})}}
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RecordAudio",
+            display_name="Record Audio",
+            category="audio",
+            inputs=[
+                IO.Custom("AUDIO_RECORD").Input("audio"),
+            ],
+            outputs=[IO.Audio.Output()],
+        )
 
-    CATEGORY = "audio"
-
-    RETURN_TYPES = ("AUDIO", )
-    FUNCTION = "load"
-
-    def load(self, audio):
+    @classmethod
+    def execute(cls, audio) -> IO.NodeOutput:
         audio_path = folder_paths.get_annotated_filepath(audio)
 
         waveform, sample_rate = load(audio_path)
         audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
-        return (audio, )
+        return IO.NodeOutput(audio)
+
+    load = execute # TODO: remove
+
 
-class TrimAudioDuration:
+class TrimAudioDuration(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(cls):
-        return {
-            "required": {
-                "audio": ("AUDIO",),
-                "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
-                "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
-            },
-        }
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="TrimAudioDuration",
+            display_name="Trim Audio Duration",
+            description="Trim audio tensor into chosen time range.",
+            category="audio",
+            inputs=[
+                IO.Audio.Input("audio"),
+                IO.Float.Input(
+                    "start_index",
+                    default=0.0,
+                    min=-0xffffffffffffffff,
+                    max=0xffffffffffffffff,
+                    step=0.01,
+                    tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
+                ),
+                IO.Float.Input(
+                    "duration",
+                    default=60.0,
+                    min=0.0,
+                    step=0.01,
+                    tooltip="Duration in seconds",
+                ),
+            ],
+            outputs=[IO.Audio.Output()],
+        )
 
-    FUNCTION = "trim"
-    RETURN_TYPES = ("AUDIO",)
-    CATEGORY = "audio"
-    DESCRIPTION = "Trim audio tensor into chosen time range."
-
-    def trim(self, audio, start_index, duration):
+    @classmethod
+    def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
         waveform = audio["waveform"]
         sample_rate = audio["sample_rate"]
         audio_length = waveform.shape[-1]
@@ -399,23 +362,30 @@ class TrimAudioDuration:
         if start_frame >= end_frame:
             raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
 
-        return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
+        return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
+
+    trim = execute # TODO: remove
+
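The lines that convert `start_index`/`duration` from seconds to sample indices fall outside the visible hunk, so the helper below is a hypothetical sketch of the arithmetic the surrounding code implies (negative start counting from the end, duration clamped to the clip):

```python
# Hypothetical sketch of TrimAudioDuration's index arithmetic; the actual
# conversion lines are not shown in this diff, so treat this as an assumption.
def trim_bounds(audio_length: int, sample_rate: int, start_index: float, duration: float) -> tuple[int, int]:
    start_frame = int(start_index * sample_rate)
    if start_frame < 0:
        start_frame += audio_length  # negative start counts from the end
    end_frame = min(start_frame + int(duration * sample_rate), audio_length)
    if start_frame >= end_frame:
        raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
    return start_frame, end_frame

# e.g. 10 s of 44.1 kHz audio, start at -2.5 s, duration 60 s -> the last 2.5 s
assert trim_bounds(441000, 44100, -2.5, 60.0) == (330750, 441000)
```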
-class SplitAudioChannels:
+class SplitAudioChannels(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "audio": ("AUDIO",),
-        }}
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="SplitAudioChannels",
+            display_name="Split Audio Channels",
+            description="Separates the audio into left and right channels.",
+            category="audio",
+            inputs=[
+                IO.Audio.Input("audio"),
+            ],
+            outputs=[
+                IO.Audio.Output(display_name="left"),
+                IO.Audio.Output(display_name="right"),
+            ],
+        )
 
-    RETURN_TYPES = ("AUDIO", "AUDIO")
-    RETURN_NAMES = ("left", "right")
-    FUNCTION = "separate"
-    CATEGORY = "audio"
-    DESCRIPTION = "Separates the audio into left and right channels."
-
-    def separate(self, audio):
+    @classmethod
+    def execute(cls, audio) -> IO.NodeOutput:
         waveform = audio["waveform"]
         sample_rate = audio["sample_rate"]
 
@@ -425,7 +395,9 @@ class SplitAudioChannels:
         left_channel = waveform[..., 0:1, :]
         right_channel = waveform[..., 1:2, :]
 
-        return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
+        return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
+
+    separate = execute # TODO: remove
+
 
 def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
@@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
     return waveform_1, waveform_2, output_sample_rate
 
-class AudioConcat:
+class AudioConcat(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "audio1": ("AUDIO",),
-            "audio2": ("AUDIO",),
-            "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
-        }}
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="AudioConcat",
+            display_name="Audio Concat",
+            description="Concatenates the audio1 to audio2 in the specified direction.",
+            category="audio",
+            inputs=[
+                IO.Audio.Input("audio1"),
+                IO.Audio.Input("audio2"),
+                IO.Combo.Input(
+                    "direction",
+                    options=['after', 'before'],
+                    default="after",
+                    tooltip="Whether to append audio2 after or before audio1.",
+                )
+            ],
+            outputs=[IO.Audio.Output()],
+        )
 
-    RETURN_TYPES = ("AUDIO",)
-    FUNCTION = "concat"
-    CATEGORY = "audio"
-    DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
-
-    def concat(self, audio1, audio2, direction):
+    @classmethod
+    def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
         waveform_1 = audio1["waveform"]
         waveform_2 = audio2["waveform"]
         sample_rate_1 = audio1["sample_rate"]
@@ -477,26 +457,33 @@ class AudioConcat:
         elif direction == 'before':
             concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
 
-        return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
+        return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
+
+    concat = execute # TODO: remove
+
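Only the signature and return of `match_audio_sample_rates` are visible in this hunk; both concat and merge call it so that two tracks share one rate before combining. A plausible sketch of such a helper (an assumption, not the PR's exact body):

```python
import torchaudio

# Hypothetical sketch; the real body of match_audio_sample_rates is outside this hunk.
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
    # Resample the lower-rate track up to the higher rate so both match.
    output_sample_rate = max(sample_rate_1, sample_rate_2)
    if sample_rate_1 != output_sample_rate:
        waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, output_sample_rate)
    if sample_rate_2 != output_sample_rate:
        waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, output_sample_rate)
    return waveform_1, waveform_2, output_sample_rate
```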
-class AudioMerge:
+class AudioMerge(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(cls):
-        return {
-            "required": {
-                "audio1": ("AUDIO",),
-                "audio2": ("AUDIO",),
-                "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
-            },
-        }
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="AudioMerge",
+            display_name="Audio Merge",
+            description="Combine two audio tracks by overlaying their waveforms.",
+            category="audio",
+            inputs=[
+                IO.Audio.Input("audio1"),
+                IO.Audio.Input("audio2"),
+                IO.Combo.Input(
+                    "merge_method",
+                    options=["add", "mean", "subtract", "multiply"],
+                    tooltip="The method used to combine the audio waveforms.",
+                )
+            ],
+            outputs=[IO.Audio.Output()],
+        )
 
-    FUNCTION = "merge"
-    RETURN_TYPES = ("AUDIO",)
-    CATEGORY = "audio"
-    DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
-
-    def merge(self, audio1, audio2, merge_method):
+    @classmethod
+    def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
         waveform_1 = audio1["waveform"]
         waveform_2 = audio2["waveform"]
         sample_rate_1 = audio1["sample_rate"]
@@ -530,85 +517,110 @@ class AudioMerge:
         if max_val > 1.0:
             waveform = waveform / max_val
 
-        return ({"waveform": waveform, "sample_rate": output_sample_rate},)
+        return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
+
+    merge = execute # TODO: remove
+
 
-class AudioAdjustVolume:
+class AudioAdjustVolume(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "audio": ("AUDIO",),
-            "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
-        }}
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="AudioAdjustVolume",
+            display_name="Audio Adjust Volume",
+            category="audio",
+            inputs=[
+                IO.Audio.Input("audio"),
+                IO.Int.Input(
+                    "volume",
+                    default=1,
+                    min=-100,
+                    max=100,
+                    tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
+                )
+            ],
+            outputs=[IO.Audio.Output()],
+        )
 
-    RETURN_TYPES = ("AUDIO",)
-    FUNCTION = "adjust_volume"
-    CATEGORY = "audio"
-
-    def adjust_volume(self, audio, volume):
+    @classmethod
+    def execute(cls, audio, volume) -> IO.NodeOutput:
         if volume == 0:
-            return (audio,)
+            return IO.NodeOutput(audio)
         waveform = audio["waveform"]
         sample_rate = audio["sample_rate"]
 
         gain = 10 ** (volume / 20)
         waveform = waveform * gain
 
-        return ({"waveform": waveform, "sample_rate": sample_rate},)
+        return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
+
+    adjust_volume = execute # TODO: remove
+
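Note that the conversion also fixes the old float default (`1.0`) on an INT input, now `default=1`. The dB-to-amplitude mapping itself is the standard `gain = 10 ** (dB / 20)`, so the tooltip's "+6 = double" is an approximation (+6 dB is a factor of about 1.995):

```python
# The dB-to-amplitude mapping used by AudioAdjustVolume.
for db in (-6, 0, 6, 20):
    print(db, round(10 ** (db / 20), 3))
# -6 -> 0.501 (about half), 0 -> 1.0, 6 -> 1.995 (about double), 20 -> 10.0
```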
-class EmptyAudio:
+class EmptyAudio(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
-            "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
-            "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
-        }}
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="EmptyAudio",
+            display_name="Empty Audio",
+            category="audio",
+            inputs=[
+                IO.Float.Input(
+                    "duration",
+                    default=60.0,
+                    min=0.0,
+                    max=0xffffffffffffffff,
+                    step=0.01,
+                    tooltip="Duration of the empty audio clip in seconds",
+                ),
+                IO.Int.Input(
+                    "sample_rate",
+                    default=44100,
+                    tooltip="Sample rate of the empty audio clip.",
+                    min=1,
+                    max=192000,
+                ),
+                IO.Int.Input(
+                    "channels",
+                    default=2,
+                    min=1,
+                    max=2,
+                    tooltip="Number of audio channels (1 for mono, 2 for stereo).",
+                ),
+            ],
+            outputs=[IO.Audio.Output()],
+        )
 
-    RETURN_TYPES = ("AUDIO",)
-    FUNCTION = "create_empty_audio"
-    CATEGORY = "audio"
-
-    def create_empty_audio(self, duration, sample_rate, channels):
+    @classmethod
+    def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
         num_samples = int(round(duration * sample_rate))
         waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
-        return ({"waveform": waveform, "sample_rate": sample_rate},)
+        return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
+
+    create_empty_audio = execute # TODO: remove
+
 
-NODE_CLASS_MAPPINGS = {
-    "EmptyLatentAudio": EmptyLatentAudio,
-    "VAEEncodeAudio": VAEEncodeAudio,
-    "VAEDecodeAudio": VAEDecodeAudio,
-    "SaveAudio": SaveAudio,
-    "SaveAudioMP3": SaveAudioMP3,
-    "SaveAudioOpus": SaveAudioOpus,
-    "LoadAudio": LoadAudio,
-    "PreviewAudio": PreviewAudio,
-    "ConditioningStableAudio": ConditioningStableAudio,
-    "RecordAudio": RecordAudio,
-    "TrimAudioDuration": TrimAudioDuration,
-    "SplitAudioChannels": SplitAudioChannels,
-    "AudioConcat": AudioConcat,
-    "AudioMerge": AudioMerge,
-    "AudioAdjustVolume": AudioAdjustVolume,
-    "EmptyAudio": EmptyAudio,
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
-    "EmptyLatentAudio": "Empty Latent Audio",
-    "VAEEncodeAudio": "VAE Encode Audio",
-    "VAEDecodeAudio": "VAE Decode Audio",
-    "PreviewAudio": "Preview Audio",
-    "LoadAudio": "Load Audio",
-    "SaveAudio": "Save Audio (FLAC)",
-    "SaveAudioMP3": "Save Audio (MP3)",
-    "SaveAudioOpus": "Save Audio (Opus)",
-    "RecordAudio": "Record Audio",
-    "TrimAudioDuration": "Trim Audio Duration",
-    "SplitAudioChannels": "Split Audio Channels",
-    "AudioConcat": "Audio Concat",
-    "AudioMerge": "Audio Merge",
-    "AudioAdjustVolume": "Audio Adjust Volume",
-    "EmptyAudio": "Empty Audio",
-}
+class AudioExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            EmptyLatentAudio,
+            VAEEncodeAudio,
+            VAEDecodeAudio,
+            SaveAudio,
+            SaveAudioMP3,
+            SaveAudioOpus,
+            LoadAudio,
+            PreviewAudio,
+            ConditioningStableAudio,
+            RecordAudio,
+            TrimAudioDuration,
+            SplitAudioChannels,
+            AudioConcat,
+            AudioMerge,
+            AudioAdjustVolume,
+            EmptyAudio,
+        ]
+
+
+async def comfy_entrypoint() -> AudioExtension:
+    return AudioExtension()
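With the extension entrypoint in place, the two mapping dicts become redundant: the node ID and display name now live in each node's schema. A small sanity-check sketch of that invariant (assuming `define_schema()` can be called offline, which holds for these save nodes):

```python
# Each converted node should carry the same identifier and display name
# that the deleted mapping dicts used to provide.
expected = {"SaveAudio": "Save Audio (FLAC)", "SaveAudioMP3": "Save Audio (MP3)"}
for node_cls in (SaveAudio, SaveAudioMP3):
    schema = node_cls.define_schema()
    assert schema.node_id == node_cls.__name__
    assert expected[schema.node_id] == schema.display_name
```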
@@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
                 io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
                 io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
                 io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
+                io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
+                #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
+                #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
             ],
             outputs=[
                 io.Model.Output(tooltip="The model with context windows applied during sampling."),
@@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
+    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
+                cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
         model = model.clone()
         model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
             context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
             context_overlap=context_overlap,
             context_stride=context_stride,
             closed_loop=closed_loop,
-            dim=dim)
+            dim=dim,
+            freenoise=freenoise,
+            cond_retain_index_list=cond_retain_index_list,
+            split_conds_to_windows=split_conds_to_windows
+        )
         # make memory usage calculation only take into account the context window latents
         comfy.context_windows.create_prepare_sampling_wrapper(model)
+        if freenoise: # no other use for this wrapper at this time
+            comfy.context_windows.create_sampler_sample_wrapper(model)
         return io.NodeOutput(model)
 
 class WanContextWindowsManualNode(ContextWindowsManualNode):
@@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
             io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
             io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
             io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
+            io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
+            #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
+            #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
         ]
         return schema
 
     @classmethod
-    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
+    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
+                cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
         context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
         context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
-        return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
+        return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
 
 
 class ContextWindowsExtension(ComfyExtension):
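The Wan subclass maps pixel-frame counts to latent-frame counts before delegating, since Wan's video VAE compresses time by 4x (plus the initial frame). The `((n - 1) // 4) + 1` formula in the hunk above works out as follows:

```python
# Pixel frames -> latent frames under 4x temporal compression,
# as used by WanContextWindowsManualNode.execute above.
def to_latent_frames(n: int) -> int:
    return max(((n - 1) // 4) + 1, 1)

assert to_latent_frames(81) == 21  # a context window of 81 frames spans 21 latents
assert to_latent_frames(16) == 4   # an overlap of 16 frames spans 4 latents
```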
@@ -2,6 +2,8 @@
 
 import torch
 import logging
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, IO
 
 def Fourier_filter(x, threshold, scale):
     # FFT
@@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
     return x_filtered.to(x.dtype)
 
-class FreeU:
+
+class FreeU(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
-                              "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
-                              "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
-                              "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="FreeU",
+            category="model_patches/unet",
+            inputs=[
+                IO.Model.Input("model"),
+                IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
+                IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
+                IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+                IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                IO.Model.Output(),
+            ],
+        )
 
-    CATEGORY = "model_patches/unet"
-
-    def patch(self, model, b1, b2, s1, s2):
+    @classmethod
+    def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
         model_channels = model.model.model_config.unet_config["model_channels"]
         scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
         on_cpu_devices = {}
@@ -59,23 +66,31 @@ class FreeU:
 
         m = model.clone()
         m.set_model_output_block_patch(output_block_patch)
-        return (m, )
+        return IO.NodeOutput(m)
 
-class FreeU_V2:
+    patch = execute # TODO: remove
+
+
+class FreeU_V2(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
-                              "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
-                              "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
-                              "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="FreeU_V2",
+            category="model_patches/unet",
+            inputs=[
+                IO.Model.Input("model"),
+                IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
+                IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
+                IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+                IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                IO.Model.Output(),
+            ],
+        )
 
-    CATEGORY = "model_patches/unet"
-
-    def patch(self, model, b1, b2, s1, s2):
+    @classmethod
+    def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
         model_channels = model.model.model_config.unet_config["model_channels"]
         scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
         on_cpu_devices = {}
@@ -105,9 +120,19 @@ class FreeU_V2:
 
         m = model.clone()
         m.set_model_output_block_patch(output_block_patch)
-        return (m, )
+        return IO.NodeOutput(m)
 
-NODE_CLASS_MAPPINGS = {
-    "FreeU": FreeU,
-    "FreeU_V2": FreeU_V2,
-}
+    patch = execute # TODO: remove
+
+
+class FreelunchExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            FreeU,
+            FreeU_V2,
+        ]
+
+
+async def comfy_entrypoint() -> FreelunchExtension:
+    return FreelunchExtension()
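Only the first and last lines of `Fourier_filter` appear in these hunks (it is unchanged by this commit). For context, the standard FreeU-style implementation of such a filter scales the low-frequency center box of the shifted spectrum; the body between the visible edges presumably looks like this:

```python
import torch

# Standard FreeU-style Fourier filter, matching the visible signature;
# the middle of the body is outside this hunk, so treat it as a sketch.
def Fourier_filter(x, threshold, scale):
    # FFT over the spatial dims, then center the spectrum
    x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    # Scale only the low-frequency box of size (2*threshold) x (2*threshold)
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # Back to the spatial domain
    x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
    return x_filtered.to(x.dtype)
```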
136 comfy_extras/nodes_kandinsky5.py Normal file
@@ -0,0 +1,136 @@
+import nodes
+import node_helpers
+import torch
+import comfy.model_management
+import comfy.utils
+
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+
+
+class Kandinsky5ImageToVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Kandinsky5ImageToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Image.Input("start_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
+                io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        cond_latent_out = {}
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            encoded = vae.encode(start_image[:, :, :, :3])
+            cond_latent_out["samples"] = encoded
+
+            mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
+
+            positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
+            negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
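The allocated latent assumes 8x spatial and 4x temporal compression with 16 channels; with the schema defaults above the shape works out as:

```python
# Latent shape implied by Kandinsky5ImageToVideo's defaults.
width, height, length, batch_size = 768, 512, 121, 1
shape = [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8]
assert shape == [1, 16, 31, 64, 96]
```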
+
+def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
+    source_mean = source.mean(dim=(1, 3, 4), keepdim=True)  # mean over C, H, W
+    source_std = source.std(dim=(1, 3, 4), keepdim=True)  # std over C, H, W
+
+    reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
+    reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
+
+    # normalization
+    normalized = (source - source_mean) / (source_std + 1e-8)
+    normalized = normalized * reference_std + reference_mean
+
+    return normalized
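In effect this is classic mean/std re-matching, except the reference statistics are clamped to stay near the source's own, so the adjustment can never overshoot. A minimal numeric check, assuming the function is importable from the new file:

```python
import torch
from comfy_extras.nodes_kandinsky5 import adaptive_mean_std_normalization

# Toy 5D latents [B, C, T, H, W]; the clamping only kicks in when the
# reference stats drift far from the source's.
source = torch.randn(1, 16, 4, 8, 8) * 2.0 + 1.0  # mean ~1.0, std ~2.0
reference = torch.randn(1, 16, 5, 8, 8) * 1.2     # mean ~0.0, std ~1.2

out = adaptive_mean_std_normalization(source, reference)
# Output stats move toward the reference's, clamped to stay near the source's
# (here roughly mean ~0.7 and std ~1.65 given the default clamp widths).
print(out.mean().item(), out.std().item())
```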
+
+
+class NormalizeVideoLatentStart(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="NormalizeVideoLatentStart",
+            category="conditioning/video_models",
+            description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
+            inputs=[
+                io.Latent.Input("latent"),
+                io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
+                io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
+            ],
+            outputs=[
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
+        if latent["samples"].shape[2] <= 1:
+            return io.NodeOutput(latent)
+        s = latent.copy()
+        samples = latent["samples"].clone()
+
+        first_frames = samples[:, :, :start_frame_count]
+        reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
+        normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
+
+        samples[:, :, :start_frame_count] = normalized_first_frames
+        s["samples"] = samples
+        return io.NodeOutput(s)
+
+
+class CLIPTextEncodeKandinsky5(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPTextEncodeKandinsky5",
+            category="advanced/conditioning/kandinsky5",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
+                io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
+        tokens = clip.tokenize(clip_l)
+        tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
+
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+
+class Kandinsky5Extension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            Kandinsky5ImageToVideo,
+            NormalizeVideoLatentStart,
+            CLIPTextEncodeKandinsky5,
+        ]
+
+
+async def comfy_entrypoint() -> Kandinsky5Extension:
+    return Kandinsky5Extension()
@@ -4,7 +4,7 @@ import torch
 import nodes
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
+import logging
 
 def reshape_latent_to(target_shape, latent, repeat_batch=True):
     if latent.shape[1:] != target_shape[1:]:
@@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode):
             return luminance * sharpened
         return io.NodeOutput(sharpen)
 
+
+class ReplaceVideoLatentFrames(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ReplaceVideoLatentFrames",
+            category="latent/batch",
+            inputs=[
+                io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
+                io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
+                io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, destination, index, source=None) -> io.NodeOutput:
+        if source is None:
+            return io.NodeOutput(destination)
+        dest_frames = destination["samples"].shape[2]
+        source_frames = source["samples"].shape[2]
+        if index < 0:
+            index = dest_frames + index
+        if index > dest_frames:
+            logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
+            return io.NodeOutput(destination)
+        if index + source_frames > dest_frames:
+            logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
+            return io.NodeOutput(destination)
+        s = source.copy()
+        s_source = source["samples"]
+        s_destination = destination["samples"].clone()
+        s_destination[:, :, index:index + s_source.shape[2]] = s_source
+        s["samples"] = s_destination
+        return io.NodeOutput(s)
+
 class LatentExtension(ComfyExtension):
     @override
@@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension):
             LatentApplyOperationCFG,
             LatentOperationTonemapReinhard,
             LatentOperationSharpen,
+            ReplaceVideoLatentFrames
         ]
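Negative `index` values resolve against the destination's latent-frame count (dim 2) before the bounds checks, so a negative index can target the tail of a video latent:

```python
# Index resolution used by ReplaceVideoLatentFrames.
dest_frames, source_frames = 31, 4
index = -4
if index < 0:
    index = dest_frames + index  # -4 -> 27
assert index + source_frames <= dest_frames  # 27 + 4 = 31: replaces the last 4 frames
```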
@ -2,22 +2,18 @@ import nodes
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from comfy.comfy_types import IO
|
from typing_extensions import override
|
||||||
from comfy_api.input_impl import VideoFromFile
|
from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
def normalize_path(path):
|
def normalize_path(path):
|
||||||
return path.replace('\\', '/')
|
return path.replace('\\', '/')
|
||||||
|
|
||||||
class Load3D():
|
class Load3D(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
|
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
|
||||||
|
|
||||||
os.makedirs(input_dir, exist_ok=True)
|
os.makedirs(input_dir, exist_ok=True)
|
||||||
@ -30,23 +26,29 @@ class Load3D():
|
|||||||
for file_path in input_path.rglob("*")
|
for file_path in input_path.rglob("*")
|
||||||
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
|
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
|
||||||
]
|
]
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="Load3D",
|
||||||
|
display_name="Load 3D & Animation",
|
||||||
|
category="3d",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
|
||||||
|
IO.Load3D.Input("image"),
|
||||||
|
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
||||||
|
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Image.Output(display_name="image"),
|
||||||
|
IO.Mask.Output(display_name="mask"),
|
||||||
|
IO.String.Output(display_name="mesh_path"),
|
||||||
|
IO.Image.Output(display_name="normal"),
|
||||||
|
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||||
|
IO.Video.Output(display_name="recording_video"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
return {"required": {
|
@classmethod
|
||||||
"model_file": (sorted(files), {"file_upload": True}),
|
def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
|
||||||
"image": ("LOAD_3D", {}),
|
|
||||||
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
|
||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
|
||||||
}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
|
|
||||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
|
|
||||||
|
|
||||||
FUNCTION = "process"
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
CATEGORY = "3d"
|
|
||||||
|
|
||||||
def process(self, model_file, image, **kwargs):
|
|
||||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||||
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
||||||
@ -61,58 +63,47 @@ class Load3D():
|
|||||||
if image['recording'] != "":
|
if image['recording'] != "":
|
||||||
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
||||||
|
|
||||||
video = VideoFromFile(recording_video_path)
|
video = InputImpl.VideoFromFile(recording_video_path)
|
||||||
|
|
||||||
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
|
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
|
||||||
|
|
||||||
class Preview3D():
|
process = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class Preview3D(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return IO.Schema(
|
||||||
"model_file": ("STRING", {"default": "", "multiline": False}),
|
node_id="Preview3D",
|
||||||
},
|
display_name="Preview 3D & Animation",
|
||||||
"optional": {
|
category="3d",
|
||||||
"camera_info": ("LOAD3D_CAMERA", {}),
|
is_experimental=True,
|
||||||
"bg_image": ("IMAGE", {})
|
is_output_node=True,
|
||||||
}}
|
inputs=[
|
||||||
|
IO.String.Input("model_file", default="", multiline=False),
|
||||||
|
IO.Load3DCamera.Input("camera_info", optional=True),
|
||||||
|
IO.Image.Input("bg_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
@classmethod
|
||||||
RETURN_TYPES = ()
|
def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "3d"
|
|
||||||
|
|
||||||
FUNCTION = "process"
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
def process(self, model_file, **kwargs):
|
|
||||||
camera_info = kwargs.get("camera_info", None)
|
camera_info = kwargs.get("camera_info", None)
|
||||||
bg_image = kwargs.get("bg_image", None)
|
bg_image = kwargs.get("bg_image", None)
|
||||||
|
return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
|
||||||
|
|
||||||
bg_image_path = None
|
process = execute # TODO: remove
|
||||||
if bg_image is not None:
|
|
||||||
|
|
||||||
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
|
|
||||||
img = Image.fromarray(img_array)
|
|
||||||
|
|
||||||
temp_dir = folder_paths.get_temp_directory()
|
class Load3DExtension(ComfyExtension):
|
||||||
filename = f"bg_{uuid.uuid4().hex}.png"
|
@override
|
||||||
bg_image_path = os.path.join(temp_dir, filename)
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
img.save(bg_image_path, compress_level=1)
|
return [
|
||||||
|
Load3D,
|
||||||
|
Preview3D,
|
||||||
|
]
|
||||||
|
|
||||||
bg_image_path = f"temp/{filename}"
|
|
||||||
|
|
||||||
return {
|
async def comfy_entrypoint() -> Load3DExtension:
|
||||||
"ui": {
|
return Load3DExtension()
|
||||||
"result": [model_file, camera_info, bg_image_path]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"Load3D": Load3D,
|
|
||||||
"Preview3D": Preview3D,
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"Load3D": "Load 3D & Animation",
|
|
||||||
"Preview3D": "Preview 3D & Animation",
|
|
||||||
}
|
|
||||||
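The hunks above migrate Load3D/Preview3D from the v1 dict API (`INPUT_TYPES`/`RETURN_TYPES`/`FUNCTION` plus `NODE_CLASS_MAPPINGS`) to the v3 schema API. A minimal sketch of the v3 pattern, assuming only the `IO`/`ComfyExtension` surface visible in this diff (the node itself is hypothetical):

```python
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO


class EchoString(IO.ComfyNode):
    # Hypothetical example node; define_schema/execute replace
    # INPUT_TYPES/FUNCTION from the v1 API.
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="EchoString",
            category="3d",
            inputs=[IO.String.Input("text", default="")],
            outputs=[IO.String.Output()],
        )

    @classmethod
    def execute(cls, text) -> IO.NodeOutput:
        return IO.NodeOutput(text)


class EchoExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [EchoString]


async def comfy_entrypoint() -> EchoExtension:
    # Replaces NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS registration.
    return EchoExtension()
```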
|
|||||||
155
comfy_extras/nodes_logic.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
from typing import TypedDict
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_api.latest import _io
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = io.MatchType.Template("switch")
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ComfySwitchNode",
|
||||||
|
display_name="Switch",
|
||||||
|
category="logic",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Boolean.Input("switch"),
|
||||||
|
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
|
||||||
|
io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(template=template, display_name="output"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check_lazy_status(cls, switch, on_false=..., on_true=...):
|
||||||
|
# We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
|
||||||
|
# This trick allows us to ignore the value of the switch and still be able to run execute().
|
||||||
|
|
||||||
|
# One of the inputs may be missing, in which case we need to evaluate the other input
|
||||||
|
if on_false is ...:
|
||||||
|
return ["on_true"]
|
||||||
|
if on_true is ...:
|
||||||
|
return ["on_false"]
|
||||||
|
# Normal lazy switch operation
|
||||||
|
if switch and on_true is None:
|
||||||
|
return ["on_true"]
|
||||||
|
if not switch and on_false is None:
|
||||||
|
return ["on_false"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_inputs(cls, switch, on_false=..., on_true=...):
|
||||||
|
# This check happens before check_lazy_status(), so we can eliminate the case where
|
||||||
|
# both inputs are missing.
|
||||||
|
if on_false is ... and on_true is ...:
|
||||||
|
return "At least one of on_false or on_true must be connected to Switch node"
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
|
||||||
|
if on_true is ...:
|
||||||
|
return io.NodeOutput(on_false)
|
||||||
|
if on_false is ...:
|
||||||
|
return io.NodeOutput(on_true)
|
||||||
|
return io.NodeOutput(on_true if switch else on_false)
|
||||||
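`check_lazy_status` above leans on `...` (Ellipsis) as a default so a disconnected input can be told apart from a connected-but-unevaluated one, which arrives as `None`. A standalone sketch of that sentinel convention (plain Python, not part of the diff):

```python
# None -> connected upstream but not evaluated yet (may still be requested)
# ...  -> not connected at all (can never be evaluated)
def pick(switch, on_false=..., on_true=...):
    if on_true is ...:       # only on_false is wired up
        return on_false
    if on_false is ...:      # only on_true is wired up
        return on_true
    return on_true if switch else on_false

assert pick(True, on_false=1) == 1               # missing branch falls back
assert pick(False, on_false=1, on_true=2) == 1   # normal switch behaviour
```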
|
|
||||||
|
|
||||||
|
class DCTestNode(io.ComfyNode):
|
||||||
|
class DCValues(TypedDict):
|
||||||
|
combo: str
|
||||||
|
string: str
|
||||||
|
integer: int
|
||||||
|
image: io.Image.Type
|
||||||
|
subcombo: dict[str, object]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="DCTestNode",
|
||||||
|
display_name="DCTest",
|
||||||
|
category="logic",
|
||||||
|
is_output_node=True,
|
||||||
|
inputs=[_io.DynamicCombo.Input("combo", options=[
|
||||||
|
_io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||||
|
_io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
||||||
|
_io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
||||||
|
_io.DynamicCombo.Option("option4", [
|
||||||
|
_io.DynamicCombo.Input("subcombo", options=[
|
||||||
|
_io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
||||||
|
_io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
||||||
|
])
|
||||||
|
])]
|
||||||
|
)],
|
||||||
|
outputs=[io.AnyType.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, combo: DCValues) -> io.NodeOutput:
|
||||||
|
combo_val = combo["combo"]
|
||||||
|
if combo_val == "option1":
|
||||||
|
return io.NodeOutput(combo["string"])
|
||||||
|
elif combo_val == "option2":
|
||||||
|
return io.NodeOutput(combo["integer"])
|
||||||
|
elif combo_val == "option3":
|
||||||
|
return io.NodeOutput(combo["image"])
|
||||||
|
elif combo_val == "option4":
|
||||||
|
return io.NodeOutput(f"{combo['subcombo']}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid combo: {combo_val}")
|
||||||
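`DCTestNode.execute` receives the whole `DynamicCombo` selection as one nested dict, per the `DCValues` TypedDict above. Hypothetical payloads, assuming the keys mirror the declared input ids:

```python
# option1 selected: only that option's own inputs accompany the choice
combo_option1 = {"combo": "option1", "string": "hello"}
# option4 selected: the nested subcombo arrives as its own dict
combo_option4 = {
    "combo": "option4",
    "subcombo": {"subcombo": "opt1", "float_x": 1.0, "float_y": 2.0},
}
```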
|
|
||||||
|
|
||||||
|
class AutogrowNamesTestNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
|
||||||
|
return io.Schema(
|
||||||
|
node_id="AutogrowNamesTestNode",
|
||||||
|
display_name="AutogrowNamesTest",
|
||||||
|
category="logic",
|
||||||
|
inputs=[
|
||||||
|
_io.Autogrow.Input("autogrow", template=template)
|
||||||
|
],
|
||||||
|
outputs=[io.String.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
vals = list(autogrow.values())
|
||||||
|
combined = ",".join([str(x) for x in vals])
|
||||||
|
return io.NodeOutput(combined)
|
||||||
|
|
||||||
|
class AutogrowPrefixTestNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="AutogrowPrefixTestNode",
|
||||||
|
display_name="AutogrowPrefixTest",
|
||||||
|
category="logic",
|
||||||
|
inputs=[
|
||||||
|
_io.Autogrow.Input("autogrow", template=template)
|
||||||
|
],
|
||||||
|
outputs=[io.String.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
vals = list(autogrow.values())
|
||||||
|
combined = ",".join([str(x) for x in vals])
|
||||||
|
return io.NodeOutput(combined)
|
||||||
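Both Autogrow test nodes join whatever values arrive, so the difference is purely in how the template names the grown inputs. A sketch of the two shapes `execute` plausibly receives (assumption: `autogrow` is an ordered name-to-value mapping):

```python
names_style = {"a": 1.0, "b": 2.0, "c": 3.0}    # TemplateNames(names=["a", "b", "c"])
prefix_style = {"float1": 1.0, "float2": 2.0}   # TemplatePrefix(prefix="float", min=1, max=10)
",".join(str(v) for v in names_style.values())  # -> "1.0,2.0,3.0"
```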
|
|
||||||
|
class LogicExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
# SwitchNode,
|
||||||
|
# DCTestNode,
|
||||||
|
# AutogrowNamesTestNode,
|
||||||
|
# AutogrowPrefixTestNode,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> LogicExtension:
|
||||||
|
return LogicExtension()
|
||||||
@ -3,11 +3,10 @@ import scipy.ndimage
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import folder_paths
|
from typing_extensions import override
|
||||||
import random
|
from comfy_api.latest import ComfyExtension, IO, UI
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
from nodes import MAX_RESOLUTION
|
|
||||||
|
|
||||||
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
||||||
source = source.to(destination.device)
|
source = source.to(destination.device)
|
||||||
@ -46,202 +45,213 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
|||||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||||
return destination
|
return destination
|
||||||
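`composite` is shared by the latent and image nodes below, which pass `multiplier=8` and `multiplier=1` respectively; the assumption here is that `x`/`y` are pixel coordinates that the multiplier maps into tensor units. A usage sketch:

```python
import torch

destination = torch.zeros(1, 4, 64, 64)   # latent of a 512x512 image
source = torch.ones(1, 4, 16, 16)         # latent of a 128x128 patch
out = composite(destination, source, x=256, y=256, mask=None, multiplier=8)
assert out.shape == destination.shape     # pasted in place, shape unchanged
```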
|
|
||||||
class LatentCompositeMasked:
|
class LatentCompositeMasked(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="LatentCompositeMasked",
|
||||||
"destination": ("LATENT",),
|
category="latent",
|
||||||
"source": ("LATENT",),
|
inputs=[
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
IO.Latent.Input("destination"),
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
IO.Latent.Input("source"),
|
||||||
"resize_source": ("BOOLEAN", {"default": False}),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
},
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
"optional": {
|
IO.Boolean.Input("resize_source", default=False),
|
||||||
"mask": ("MASK",),
|
IO.Mask.Input("mask", optional=True),
|
||||||
}
|
],
|
||||||
}
|
outputs=[IO.Latent.Output()],
|
||||||
RETURN_TYPES = ("LATENT",)
|
)
|
||||||
FUNCTION = "composite"
|
|
||||||
|
|
||||||
CATEGORY = "latent"
|
@classmethod
|
||||||
|
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||||
def composite(self, destination, source, x, y, resize_source, mask = None):
|
|
||||||
output = destination.copy()
|
output = destination.copy()
|
||||||
destination = destination["samples"].clone()
|
destination = destination["samples"].clone()
|
||||||
source = source["samples"]
|
source = source["samples"]
|
||||||
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
|
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
|
||||||
return (output,)
|
return IO.NodeOutput(output)
|
||||||
|
|
||||||
class ImageCompositeMasked:
|
composite = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCompositeMasked(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ImageCompositeMasked",
|
||||||
"destination": ("IMAGE",),
|
category="image",
|
||||||
"source": ("IMAGE",),
|
inputs=[
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Image.Input("destination"),
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Image.Input("source"),
|
||||||
"resize_source": ("BOOLEAN", {"default": False}),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
},
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
"optional": {
|
IO.Boolean.Input("resize_source", default=False),
|
||||||
"mask": ("MASK",),
|
IO.Mask.Input("mask", optional=True),
|
||||||
}
|
],
|
||||||
}
|
outputs=[IO.Image.Output()],
|
||||||
RETURN_TYPES = ("IMAGE",)
|
)
|
||||||
FUNCTION = "composite"
|
|
||||||
|
|
||||||
CATEGORY = "image"
|
@classmethod
|
||||||
|
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||||
def composite(self, destination, source, x, y, resize_source, mask = None):
|
|
||||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||||
destination = destination.clone().movedim(-1, 1)
|
destination = destination.clone().movedim(-1, 1)
|
||||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||||
return (output,)
|
return IO.NodeOutput(output)
|
||||||
|
|
||||||
class MaskToImage:
|
composite = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class MaskToImage(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="MaskToImage",
|
||||||
"mask": ("MASK",),
|
display_name="Convert Mask to Image",
|
||||||
}
|
category="mask",
|
||||||
}
|
inputs=[
|
||||||
|
IO.Mask.Input("mask"),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("IMAGE",)
|
|
||||||
FUNCTION = "mask_to_image"
|
|
||||||
|
|
||||||
def mask_to_image(self, mask):
|
|
||||||
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||||
return (result,)
|
return IO.NodeOutput(result)
|
||||||
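The reshape/movedim/expand chain above turns a batch of masks into grey RGB images. A shape walk-through (a sketch, not part of the diff):

```python
import torch

mask = torch.rand(2, 512, 512)                        # (B, H, W)
img = mask.reshape((-1, 1, 512, 512)).movedim(1, -1)  # (B, H, W, 1)
img = img.expand(-1, -1, -1, 3)                       # (B, H, W, 3), grey RGB
assert img.shape == (2, 512, 512, 3)
```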
|
|
||||||
class ImageToMask:
|
mask_to_image = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageToMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ImageToMask",
|
||||||
"image": ("IMAGE",),
|
display_name="Convert Image to Mask",
|
||||||
"channel": (["red", "green", "blue", "alpha"],),
|
category="mask",
|
||||||
}
|
inputs=[
|
||||||
}
|
IO.Image.Input("image"),
|
||||||
|
IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, image, channel) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
FUNCTION = "image_to_mask"
|
|
||||||
|
|
||||||
def image_to_mask(self, image, channel):
|
|
||||||
channels = ["red", "green", "blue", "alpha"]
|
channels = ["red", "green", "blue", "alpha"]
|
||||||
mask = image[:, :, :, channels.index(channel)]
|
mask = image[:, :, :, channels.index(channel)]
|
||||||
return (mask,)
|
return IO.NodeOutput(mask)
|
||||||
|
|
||||||
class ImageColorToMask:
|
image_to_mask = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageColorToMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ImageColorToMask",
|
||||||
"image": ("IMAGE",),
|
category="mask",
|
||||||
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
|
inputs=[
|
||||||
}
|
IO.Image.Input("image"),
|
||||||
}
|
IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, image, color) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
FUNCTION = "image_to_mask"
|
|
||||||
|
|
||||||
def image_to_mask(self, image, color):
|
|
||||||
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
|
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
|
||||||
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
|
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
|
||||||
mask = torch.where(temp == color, 1.0, 0).float()
|
mask = torch.where(temp == color, 1.0, 0).float()
|
||||||
return (mask,)
|
return IO.NodeOutput(mask)
|
||||||
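The bit-shifts above pack each pixel into a 24-bit `0xRRGGBB` integer so it can be compared against the `color` input directly. A small check of that packing:

```python
import torch

rgb = torch.tensor([[[[1.0, 0.0, 0.0]]]])             # (1,1,1,3) image, pure red
t = (torch.clamp(rgb, 0, 1.0) * 255.0).round().to(torch.int)
packed = (t[..., 0] << 16) + (t[..., 1] << 8) + t[..., 2]
assert packed.item() == 0xFF0000                      # red packs to 0xFF0000
```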
|
|
||||||
class SolidMask:
|
image_to_mask = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class SolidMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="SolidMask",
|
||||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
category="mask",
|
||||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
}
|
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, value, width, height) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "solid"
|
|
||||||
|
|
||||||
def solid(self, value, width, height):
|
|
||||||
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
|
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
|
||||||
return (out,)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
class InvertMask:
|
solid = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class InvertMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="InvertMask",
|
||||||
"mask": ("MASK",),
|
category="mask",
|
||||||
}
|
inputs=[
|
||||||
}
|
IO.Mask.Input("mask"),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "invert"
|
|
||||||
|
|
||||||
def invert(self, mask):
|
|
||||||
out = 1.0 - mask
|
out = 1.0 - mask
|
||||||
return (out,)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
class CropMask:
|
invert = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class CropMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="CropMask",
|
||||||
"mask": ("MASK",),
|
category="mask",
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
inputs=[
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Mask.Input("mask"),
|
||||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask, x, y, width, height) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "crop"
|
|
||||||
|
|
||||||
def crop(self, mask, x, y, width, height):
|
|
||||||
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
||||||
out = mask[:, y:y + height, x:x + width]
|
out = mask[:, y:y + height, x:x + width]
|
||||||
return (out,)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
class MaskComposite:
|
crop = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class MaskComposite(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="MaskComposite",
|
||||||
"destination": ("MASK",),
|
category="mask",
|
||||||
"source": ("MASK",),
|
inputs=[
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Mask.Input("destination"),
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Mask.Input("source"),
|
||||||
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "combine"
|
|
||||||
|
|
||||||
def combine(self, destination, source, x, y, operation):
|
|
||||||
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
|
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
|
||||||
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
|
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
|
||||||
|
|
||||||
@ -267,28 +277,29 @@ class MaskComposite:
|
|||||||
|
|
||||||
output = torch.clamp(output, 0.0, 1.0)
|
output = torch.clamp(output, 0.0, 1.0)
|
||||||
|
|
||||||
return (output,)
|
return IO.NodeOutput(output)
|
||||||
|
|
||||||
class FeatherMask:
|
combine = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class FeatherMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="FeatherMask",
|
||||||
"mask": ("MASK",),
|
category="mask",
|
||||||
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
inputs=[
|
||||||
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Mask.Input("mask"),
|
||||||
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "feather"
|
|
||||||
|
|
||||||
def feather(self, mask, left, top, right, bottom):
|
|
||||||
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
|
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
|
||||||
|
|
||||||
left = min(left, output.shape[-1])
|
left = min(left, output.shape[-1])
|
||||||
@ -312,26 +323,28 @@ class FeatherMask:
|
|||||||
feather_rate = (y + 1) / bottom
|
feather_rate = (y + 1) / bottom
|
||||||
output[:, -y, :] *= feather_rate
|
output[:, -y, :] *= feather_rate
|
||||||
|
|
||||||
return (output,)
|
return IO.NodeOutput(output)
|
||||||
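The feathering loop (see the `@ -312,26 +323,28` hunk below) applies a linear ramp per edge: a row `y` steps in from the border keeps `(y + 1) / border_width` of its value. The ramp in isolation:

```python
import torch

bottom = 4
ramp = torch.tensor([(y + 1) / bottom for y in range(bottom)])
# ramp == tensor([0.2500, 0.5000, 0.7500, 1.0000]); values fade toward the edge
```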
|
|
||||||
class GrowMask:
|
feather = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class GrowMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="GrowMask",
|
||||||
"mask": ("MASK",),
|
display_name="Grow Mask",
|
||||||
"expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
|
category="mask",
|
||||||
"tapered_corners": ("BOOLEAN", {"default": True}),
|
inputs=[
|
||||||
},
|
IO.Mask.Input("mask"),
|
||||||
}
|
IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
IO.Boolean.Input("tapered_corners", default=True),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "expand_mask"
|
|
||||||
|
|
||||||
def expand_mask(self, mask, expand, tapered_corners):
|
|
||||||
c = 0 if tapered_corners else 1
|
c = 0 if tapered_corners else 1
|
||||||
kernel = np.array([[c, 1, c],
|
kernel = np.array([[c, 1, c],
|
||||||
[1, 1, 1],
|
[1, 1, 1],
|
||||||
@ -347,69 +360,74 @@ class GrowMask:
|
|||||||
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
|
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
|
||||||
output = torch.from_numpy(output)
|
output = torch.from_numpy(output)
|
||||||
out.append(output)
|
out.append(output)
|
||||||
return (torch.stack(out, dim=0),)
|
return IO.NodeOutput(torch.stack(out, dim=0))
|
||||||
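GrowMask uses grey-scale morphology: a positive `expand` dilates, a negative one erodes, and zeroing the corners of the 3x3 footprint produces the tapered growth. A standalone sketch:

```python
import numpy as np
import scipy.ndimage

c = 0  # tapered_corners=True zeroes the diagonal neighbours
kernel = np.array([[c, 1, c],
                   [1, 1, 1],
                   [c, 1, c]])
m = np.zeros((5, 5)); m[2, 2] = 1.0
grown = scipy.ndimage.grey_dilation(m, footprint=kernel)     # plus-shaped blob
shrunk = scipy.ndimage.grey_erosion(grown, footprint=kernel)
```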
|
|
||||||
class ThresholdMask:
|
expand_mask = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ThresholdMask",
|
||||||
"mask": ("MASK",),
|
category="mask",
|
||||||
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
inputs=[
|
||||||
}
|
IO.Mask.Input("mask"),
|
||||||
}
|
IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask, value) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
FUNCTION = "image_to_mask"
|
|
||||||
|
|
||||||
def image_to_mask(self, mask, value):
|
|
||||||
mask = (mask > value).float()
|
mask = (mask > value).float()
|
||||||
return (mask,)
|
return IO.NodeOutput(mask)
|
||||||
|
|
||||||
|
image_to_mask = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
# Mask Preview - original implementation from
|
# Mask Preview - original implementation from
|
||||||
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
||||||
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
|
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
|
||||||
class MaskPreview(nodes.SaveImage):
|
class MaskPreview(IO.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.output_dir = folder_paths.get_temp_directory()
|
def define_schema(cls):
|
||||||
self.type = "temp"
|
return IO.Schema(
|
||||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
node_id="MaskPreview",
|
||||||
self.compress_level = 4
|
display_name="Preview Mask",
|
||||||
|
category="mask",
|
||||||
|
description="Previews the input mask.",
|
||||||
|
inputs=[
|
||||||
|
IO.Mask.Input("mask"),
|
||||||
|
],
|
||||||
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput:
|
||||||
return {
|
return IO.NodeOutput(ui=UI.PreviewMask(mask))
|
||||||
"required": {"mask": ("MASK",), },
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
|
||||||
}
|
|
||||||
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "mask"
|
|
||||||
|
|
||||||
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
|
||||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
|
||||||
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class MaskExtension(ComfyExtension):
|
||||||
"LatentCompositeMasked": LatentCompositeMasked,
|
@override
|
||||||
"ImageCompositeMasked": ImageCompositeMasked,
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
"MaskToImage": MaskToImage,
|
return [
|
||||||
"ImageToMask": ImageToMask,
|
LatentCompositeMasked,
|
||||||
"ImageColorToMask": ImageColorToMask,
|
ImageCompositeMasked,
|
||||||
"SolidMask": SolidMask,
|
MaskToImage,
|
||||||
"InvertMask": InvertMask,
|
ImageToMask,
|
||||||
"CropMask": CropMask,
|
ImageColorToMask,
|
||||||
"MaskComposite": MaskComposite,
|
SolidMask,
|
||||||
"FeatherMask": FeatherMask,
|
InvertMask,
|
||||||
"GrowMask": GrowMask,
|
CropMask,
|
||||||
"ThresholdMask": ThresholdMask,
|
MaskComposite,
|
||||||
"MaskPreview": MaskPreview
|
FeatherMask,
|
||||||
}
|
GrowMask,
|
||||||
|
ThresholdMask,
|
||||||
|
MaskPreview,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"ImageToMask": "Convert Image to Mask",
|
async def comfy_entrypoint() -> MaskExtension:
|
||||||
"MaskToImage": "Convert Mask to Image",
|
return MaskExtension()
|
||||||
}
|
|
||||||
|
|||||||
@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode):
|
|||||||
return io.NodeOutput(m)
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
# Sampling
|
|
||||||
"PatchModelAddDownscale": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
class ModelDownscaleExtension(ComfyExtension):
|
class ModelDownscaleExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import comfy.ops
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
|
import comfy.ldm.lumina.controlnet
|
||||||
|
|
||||||
|
|
||||||
class BlockWiseControlBlock(torch.nn.Module):
|
class BlockWiseControlBlock(torch.nn.Module):
|
||||||
@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
|
|||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
def z_image_convert(sd):
|
||||||
|
replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
|
||||||
|
".attention.norm_k.weight": ".attention.k_norm.weight",
|
||||||
|
".attention.norm_q.weight": ".attention.q_norm.weight",
|
||||||
|
".attention.to_out.0.weight": ".attention.out.weight"
|
||||||
|
}
|
||||||
|
|
||||||
|
out_sd = {}
|
||||||
|
for k in sorted(sd.keys()):
|
||||||
|
w = sd[k]
|
||||||
|
|
||||||
|
k_out = k
|
||||||
|
if k_out.endswith(".attention.to_k.weight"):
|
||||||
|
cc = [w]
|
||||||
|
continue
|
||||||
|
if k_out.endswith(".attention.to_q.weight"):
|
||||||
|
cc = [w] + cc
|
||||||
|
continue
|
||||||
|
if k_out.endswith(".attention.to_v.weight"):
|
||||||
|
cc = cc + [w]
|
||||||
|
w = torch.cat(cc, dim=0)
|
||||||
|
k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
|
||||||
|
|
||||||
|
for r, rr in replace_keys.items():
|
||||||
|
k_out = k_out.replace(r, rr)
|
||||||
|
out_sd[k_out] = w
|
||||||
|
|
||||||
|
return out_sd
|
||||||
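`z_image_convert` fuses the separate q/k/v projections into one `qkv` weight. Because `sorted()` yields `...to_k` before `...to_q` before `...to_v`, the buffer is assembled as `[q, k, v]` (k starts the list, q is prepended, v is appended) before the `cat`. A quick check:

```python
import torch

sd = {"layer.attention.to_q.weight": torch.ones(4, 8),
      "layer.attention.to_k.weight": torch.ones(4, 8) * 2,
      "layer.attention.to_v.weight": torch.ones(4, 8) * 3}
out = z_image_convert(sd)
assert out["layer.attention.qkv.weight"].shape == (12, 8)   # q, k, v stacked on dim 0
```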
|
|
||||||
class ModelPatchLoader:
|
class ModelPatchLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -211,6 +241,9 @@ class ModelPatchLoader:
|
|||||||
elif 'feature_embedder.mid_layer_norm.bias' in sd:
|
elif 'feature_embedder.mid_layer_norm.bias' in sd:
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
|
||||||
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
|
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
||||||
|
sd = z_image_convert(sd)
|
||||||
|
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
|
|
||||||
model.load_state_dict(sd)
|
model.load_state_dict(sd)
|
||||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
@ -263,6 +296,69 @@ class DiffSynthCnetPatch:
|
|||||||
def models(self):
|
def models(self):
|
||||||
return [self.model_patch]
|
return [self.model_patch]
|
||||||
|
|
||||||
|
class ZImageControlPatch:
|
||||||
|
def __init__(self, model_patch, vae, image, strength):
|
||||||
|
self.model_patch = model_patch
|
||||||
|
self.vae = vae
|
||||||
|
self.image = image
|
||||||
|
self.strength = strength
|
||||||
|
self.encoded_image = self.encode_latent_cond(image)
|
||||||
|
self.encoded_image_size = (image.shape[1], image.shape[2])
|
||||||
|
self.temp_data = None
|
||||||
|
|
||||||
|
def encode_latent_cond(self, image):
|
||||||
|
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
|
||||||
|
return latent_image
|
||||||
|
|
||||||
|
def __call__(self, kwargs):
|
||||||
|
x = kwargs.get("x")
|
||||||
|
img = kwargs.get("img")
|
||||||
|
txt = kwargs.get("txt")
|
||||||
|
pe = kwargs.get("pe")
|
||||||
|
vec = kwargs.get("vec")
|
||||||
|
block_index = kwargs.get("block_index")
|
||||||
|
spacial_compression = self.vae.spacial_compression_encode()
|
||||||
|
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
||||||
|
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
||||||
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
|
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
|
||||||
|
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
|
||||||
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
|
|
||||||
|
cnet_index = (block_index // 5)
|
||||||
|
cnet_index_float = (block_index / 5)
|
||||||
|
|
||||||
|
kwargs.pop("img") # we do ops in place
|
||||||
|
kwargs.pop("txt")
|
||||||
|
|
||||||
|
cnet_blocks = self.model_patch.model.n_control_layers
|
||||||
|
if cnet_index_float > (cnet_blocks - 1):
|
||||||
|
self.temp_data = None
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
if self.temp_data is None or self.temp_data[0] > cnet_index:
|
||||||
|
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
||||||
|
|
||||||
|
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
|
||||||
|
next_layer = self.temp_data[0] + 1
|
||||||
|
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
||||||
|
|
||||||
|
if cnet_index_float == self.temp_data[0]:
|
||||||
|
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
||||||
|
if cnet_blocks == self.temp_data[0] + 1:
|
||||||
|
self.temp_data = None
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def to(self, device_or_dtype):
|
||||||
|
if isinstance(device_or_dtype, torch.device):
|
||||||
|
self.encoded_image = self.encoded_image.to(device_or_dtype)
|
||||||
|
self.temp_data = None
|
||||||
|
return self
|
||||||
|
|
||||||
|
def models(self):
|
||||||
|
return [self.model_patch]
|
||||||
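`ZImageControlPatch` walks one control layer per five transformer blocks: `block_index // 5` picks the layer, and the residual is added only when `block_index / 5` is exact, i.e. at multiples of five (assumption: the patch is called once per double block). The arithmetic in isolation:

```python
for block_index in range(12):
    cnet_index = block_index // 5
    applies = (block_index / 5) == cnet_index
    # applies is True at block_index 0, 5, 10 -> control layers 0, 1, 2
```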
|
|
||||||
class QwenImageDiffsynthControlnet:
|
class QwenImageDiffsynthControlnet:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -289,6 +385,9 @@ class QwenImageDiffsynthControlnet:
|
|||||||
mask = mask.unsqueeze(2)
|
mask = mask.unsqueeze(2)
|
||||||
mask = 1.0 - mask
|
mask = 1.0 - mask
|
||||||
|
|
||||||
|
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
||||||
|
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
|
||||||
|
else:
|
||||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||||
return (model_patched,)
|
return (model_patched,)
|
||||||
|
|
||||||
|
|||||||
@ -623,7 +623,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||||
if multi_res:
|
if multi_res:
|
||||||
# use first latent as dummy latent if multi_res
|
# use first latent as dummy latent if multi_res
|
||||||
latents = latents[0].repeat(num_images, 1, 1, 1)
|
latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
|
||||||
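The repeat fix above matters because the old call hard-coded four dimensions; any latent with a different rank, such as a 5-D video latent, would fail. A sketch of the rank-agnostic form:

```python
import torch

lat = torch.zeros(1, 16, 8, 32, 32)            # e.g. (B, C, T, H, W) video latent
num_images = 3
reps = (num_images,) + (1,) * (lat.ndim - 1)   # (3, 1, 1, 1, 1)
assert lat.repeat(reps).shape == (3, 16, 8, 32, 32)
```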
guider.sample(
|
guider.sample(
|
||||||
noise.generate_noise({"samples": latents}),
|
noise.generate_noise({"samples": latents}),
|
||||||
latents,
|
latents,
|
||||||
|
|||||||
@ -88,7 +88,7 @@ class SaveVideo(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
|
def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
|
||||||
width, height = video.get_dimensions()
|
width, height = video.get_dimensions()
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||||
filename_prefix,
|
filename_prefix,
|
||||||
@ -108,7 +108,7 @@ class SaveVideo(io.ComfyNode):
|
|||||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||||
video.save_to(
|
video.save_to(
|
||||||
os.path.join(full_output_folder, file),
|
os.path.join(full_output_folder, file),
|
||||||
format=format,
|
format=VideoContainer(format),
|
||||||
codec=codec,
|
codec=codec,
|
||||||
metadata=saved_metadata
|
metadata=saved_metadata
|
||||||
)
|
)
|
||||||
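`VideoContainer(format)` coerces the raw combo string into the container type `save_to` expects. A stand-in enum illustrating the pattern (assumption: `VideoContainer` behaves like a value-backed `Enum`; these members are hypothetical):

```python
from enum import Enum

class Container(Enum):   # stand-in for VideoContainer
    MP4 = "mp4"
    WEBM = "webm"

assert Container("mp4") is Container.MP4   # string from the UI -> enum member
```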
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.75"
|
__version__ = "0.3.76"
|
||||||
|
|||||||
@ -63,9 +63,9 @@ def cuda_malloc_supported():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
if not args.cuda_malloc:
|
|
||||||
try:
|
|
||||||
version = ""
|
version = ""
|
||||||
|
|
||||||
|
try:
|
||||||
torch_spec = importlib.util.find_spec("torch")
|
torch_spec = importlib.util.find_spec("torch")
|
||||||
for folder in torch_spec.submodule_search_locations:
|
for folder in torch_spec.submodule_search_locations:
|
||||||
ver_file = os.path.join(folder, "version.py")
|
ver_file = os.path.join(folder, "version.py")
|
||||||
@ -74,7 +74,11 @@ if not args.cuda_malloc:
|
|||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
version = module.__version__
|
version = module.__version__
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not args.cuda_malloc:
|
||||||
|
try:
|
||||||
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
||||||
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
|
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
|
||||||
args.cuda_malloc = cuda_malloc_supported()
|
args.cuda_malloc = cuda_malloc_supported()
|
||||||
@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc:
|
|||||||
env_var += ",backend:cudaMallocAsync"
|
env_var += ",backend:cudaMallocAsync"
|
||||||
|
|
||||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||||
|
|
||||||
|
def get_torch_version_noimport():
|
||||||
|
return str(version)
|
||||||
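`get_torch_version_noimport` exists so the allocator env var can be decided before torch is ever imported: it locates torch's `version.py` via the module spec and executes only that file. A self-contained sketch of the same probe:

```python
import importlib.util
import os

def probe_torch_version() -> str:
    spec = importlib.util.find_spec("torch")
    if spec is None:
        return ""
    for folder in spec.submodule_search_locations:
        ver_file = os.path.join(folder, "version.py")
        if os.path.isfile(ver_file):
            s = importlib.util.spec_from_file_location("torch_version", ver_file)
            module = importlib.util.module_from_spec(s)
            s.loader.exec_module(module)   # runs version.py, not torch itself
            return module.__version__
    return ""
```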
|
|||||||
40
execution.py
@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input
|
|||||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||||
from comfy_execution.utils import CurrentNodeContext
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||||
from comfy_api.latest import io
|
from comfy_api.latest import io, _io
|
||||||
|
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
@ -76,7 +76,7 @@ class IsChangedCache:
|
|||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
|
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
try:
|
try:
|
||||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||||
@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
|||||||
|
|
||||||
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
||||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||||
|
v3_data: io.V3Data = {}
|
||||||
if is_v3:
|
if is_v3:
|
||||||
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
||||||
else:
|
else:
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
|||||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||||
if h[x] == "API_KEY_COMFY_ORG":
|
if h[x] == "API_KEY_COMFY_ORG":
|
||||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||||
return input_data_all, missing_keys, hidden_inputs_v3
|
v3_data["hidden_inputs"] = hidden_inputs_v3
|
||||||
|
return input_data_all, missing_keys, v3_data
|
||||||
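`get_input_data` now returns a `v3_data` dict instead of a bare hidden-inputs mapping: hidden inputs travel under one key, and the rest carries dynamic-input metadata consumed later by `_io.build_nested_inputs`. A speculative sketch of the shape (the exact `V3Data` contents are internal to `comfy_api`):

```python
v3_data = {
    "hidden_inputs": {"prompt": None, "extra_pnginfo": None},  # hypothetical values
    # ...plus whatever dynamic-input metadata INPUT_TYPES(live_inputs=...) emits
}
```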
|
|
||||||
map_node_over_list = None #Don't hook this please
|
map_node_over_list = None #Don't hook this please
|
||||||
|
|
||||||
@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
|
|||||||
raise exc
|
raise exc
|
||||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||||
|
|
||||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||||
|
|
||||||
@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
|||||||
if is_class(obj):
|
if is_class(obj):
|
||||||
type_obj = obj
|
type_obj = obj
|
||||||
obj.VALIDATE_CLASS()
|
obj.VALIDATE_CLASS()
|
||||||
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
|
||||||
# otherwise, use class instance to populate/reuse some fields
|
# otherwise, use class instance to populate/reuse some fields
|
||||||
else:
|
else:
|
||||||
type_obj = type(obj)
|
type_obj = type(obj)
|
||||||
type_obj.VALIDATE_CLASS()
|
type_obj.VALIDATE_CLASS()
|
||||||
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
|
||||||
f = make_locked_method_func(type_obj, func, class_clone)
|
f = make_locked_method_func(type_obj, func, class_clone)
|
||||||
|
# in case of dynamic inputs, restructure inputs to expected nested dict
|
||||||
|
if v3_data is not None:
|
||||||
|
inputs = _io.build_nested_inputs(inputs, v3_data)
|
||||||
# V1
|
# V1
|
||||||
else:
|
else:
|
||||||
f = getattr(obj, func)
|
f = getattr(obj, func)
|
||||||
@ -320,8 +325,8 @@ def merge_result_data(results, obj):
|
|||||||
output.append([o[i] for o in results])
|
output.append([o[i] for o in results])
|
||||||
return output
|
return output
|
||||||
|
|
||||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||||
if has_pending_task:
|
if has_pending_task:
|
||||||
return return_values, {}, False, has_pending_task
|
return return_values, {}, False, has_pending_task
|
||||||
@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
has_subgraph = False
|
has_subgraph = False
|
||||||
else:
|
else:
|
||||||
get_progress_state().start_progress(unique_id)
|
get_progress_state().start_progress(unique_id)
|
||||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
else:
|
else:
|
||||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||||
if lazy_status_present:
|
if lazy_status_present:
|
||||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
|
||||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||||
@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
def pre_execute_cb(call_index):
|
def pre_execute_cb(call_index):
|
||||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||||
if has_pending_tasks:
|
if has_pending_tasks:
|
||||||
pending_async_nodes[unique_id] = output_data
|
pending_async_nodes[unique_id] = output_data
|
||||||
unblock = execution_list.add_external_block(unique_id)
|
unblock = execution_list.add_external_block(unique_id)
|
||||||
@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
class_type = prompt[unique_id]['class_type']
|
class_type = prompt[unique_id]['class_type']
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
|
||||||
class_inputs = obj_class.INPUT_TYPES()
|
|
||||||
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
|
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
valid = True
|
valid = True
|
||||||
|
|
||||||
validate_function_inputs = []
|
validate_function_inputs = []
|
||||||
validate_has_kwargs = False
|
validate_has_kwargs = False
|
||||||
if issubclass(obj_class, _ComfyNodeInternal):
|
if issubclass(obj_class, _ComfyNodeInternal):
|
||||||
|
class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
||||||
validate_function_name = "validate_inputs"
|
validate_function_name = "validate_inputs"
|
||||||
validate_function = first_real_override(obj_class, validate_function_name)
|
validate_function = first_real_override(obj_class, validate_function_name)
|
||||||
else:
|
else:
|
||||||
|
class_inputs = obj_class.INPUT_TYPES()
|
||||||
validate_function_name = "VALIDATE_INPUTS"
|
validate_function_name = "VALIDATE_INPUTS"
|
||||||
validate_function = getattr(obj_class, validate_function_name, None)
|
validate_function = getattr(obj_class, validate_function_name, None)
|
||||||
if validate_function is not None:
|
if validate_function is not None:
|
||||||
@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
validate_has_kwargs = argspec.varkw is not None
|
validate_has_kwargs = argspec.varkw is not None
|
||||||
received_types = {}
|
received_types = {}
|
||||||
|
|
||||||
|
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
|
||||||
|
|
||||||
for x in valid_inputs:
|
for x in valid_inputs:
|
||||||
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
|
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
|
||||||
assert extra_info is not None
|
assert extra_info is not None
|
||||||
@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||||
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
|
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
|
||||||
input_filtered = {}
|
input_filtered = {}
|
||||||
for x in input_data_all:
|
for x in input_data_all:
|
||||||
if x in validate_function_inputs or validate_has_kwargs:
|
if x in validate_function_inputs or validate_has_kwargs:
|
||||||
@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
|||||||
if 'input_types' in validate_function_inputs:
|
if 'input_types' in validate_function_inputs:
|
||||||
input_filtered['input_types'] = [received_types]
|
input_filtered['input_types'] = [received_types]
|
||||||
|
|
||||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
|
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
|
||||||
ret = await resolve_map_node_over_list_results(ret)
|
ret = await resolve_map_node_over_list_results(ret)
|
||||||
for x in input_filtered:
|
for x in input_filtered:
|
||||||
for i, r in enumerate(ret):
|
for i, r in enumerate(ret):
|
||||||
|
|||||||
33
main.py
@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state
|
|||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_api import feature_flags
|
from comfy_api import feature_flags
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||||
@ -22,6 +23,23 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_comfyui_manager_unavailable():
|
||||||
|
if not args.windows_standalone_build:
|
||||||
|
logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
|
||||||
|
args.enable_manager = False
|
||||||
|
|
||||||
|
|
||||||
|
if args.enable_manager:
|
||||||
|
if importlib.util.find_spec("comfyui_manager"):
|
||||||
|
import comfyui_manager
|
||||||
|
|
||||||
|
if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'):
|
||||||
|
handle_comfyui_manager_unavailable()
|
||||||
|
else:
|
||||||
|
handle_comfyui_manager_unavailable()
|
||||||
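The guard above distinguishes a pip-installed `comfyui_manager` (a regular package whose `__file__` ends in `__init__.py`) from a source checkout or namespace package. The same check as a standalone sketch:

```python
import importlib.util

def manager_is_installed() -> bool:
    if not importlib.util.find_spec("comfyui_manager"):
        return False
    import comfyui_manager
    # A namespace package or stray folder lacks a real __init__.py
    return bool(comfyui_manager.__file__) and \
        comfyui_manager.__file__.endswith("__init__.py")
```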
|
|
||||||
|
|
||||||
def apply_custom_paths():
|
def apply_custom_paths():
|
||||||
# extra model paths
|
# extra model paths
|
||||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||||
@ -84,6 +102,11 @@ def execute_prestartup_script():

         for possible_module in possible_modules:
             module_path = os.path.join(custom_node_path, possible_module)
+
+            if args.enable_manager:
+                if comfyui_manager.should_be_disabled(module_path):
+                    continue
+
             if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
                 continue

@ -106,6 +129,10 @@ def execute_prestartup_script():
             logging.info("")

 apply_custom_paths()

+if args.enable_manager:
+    comfyui_manager.prestartup()
+
 execute_prestartup_script()

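The `should_be_disabled` gate above reappears later in `nodes.py` during custom-node import. A self-contained sketch of the pattern, with the manager call stubbed out (only the `should_be_disabled` name and the skip-before-load behavior come from the diff; the stub's rule and the paths are illustrative):

```python
# Stub standing in for comfyui_manager.should_be_disabled(); the real
# implementation consults the manager's node policy.
def should_be_disabled(module_path: str) -> bool:
    return module_path.endswith(".blocked")  # illustrative rule only

for module_path in ["custom_nodes/pack_a", "custom_nodes/pack_b.blocked"]:
    if should_be_disabled(module_path):
        print(f"Blocked by policy: {module_path}")
        continue  # the blocked pack is never imported
    print(f"loading prestartup script for {module_path}")
```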
@ -145,6 +172,9 @@ if __name__ == "__main__":
     os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

     import cuda_malloc
+    if "rocm" in cuda_malloc.get_torch_version_noimport():
+        os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
+

 if 'torch' in sys.modules:
     logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
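The placement of the `OCL_SET_SVM_SIZE` export matters: it lands before the first real `import torch`, which is also why the very next lines warn when torch is already in `sys.modules`. A sketch of the ordering constraint (the variable name and value come from the diff; that the setting is only read at backend initialization is an assumption consistent with the warning above, not something this sketch verifies):

```python
# Environment variables consumed during backend initialization must be set
# before torch is first imported; setting them afterwards is assumed to be
# too late, hence the "Torch already imported" warning in main.py.
import os
import sys

assert 'torch' not in sys.modules, "torch imported too early"
os.environ['OCL_SET_SVM_SIZE'] = '262144'  # value from the diff above

import torch  # noqa: E402  (deliberately imported after the env setup)
```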
@ -328,6 +358,9 @@ def start_comfyui(asyncio_loop=None):
         asyncio.set_event_loop(asyncio_loop)
     prompt_server = server.PromptServer(asyncio_loop)

+    if args.enable_manager and not args.disable_manager_ui:
+        comfyui_manager.start()
+
     hook_breaker_ac10a0.save_functions()
     asyncio_loop.run_until_complete(nodes.init_extra_nodes(
         init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
manager_requirements.txt (1 change, new file)
@ -0,0 +1 @@
+comfyui_manager==4.0.3b4
nodes.py (15 changes)
@ -43,6 +43,9 @@ import folder_paths
 import latent_preview
 import node_helpers

+if args.enable_manager:
+    import comfyui_manager
+
 def before_node_execution():
     comfy.model_management.throw_exception_if_processing_interrupted()

@ -939,7 +942,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
@ -967,7 +970,7 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                               "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
@ -2243,6 +2246,12 @@ async def init_external_custom_nodes():
             if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
                 logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
                 continue
+
+            if args.enable_manager:
+                if comfyui_manager.should_be_disabled(module_path):
+                    logging.info(f"Blocked by policy: {module_path}")
+                    continue
+
             time_before = time.perf_counter()
             success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
             node_import_times.append((time.perf_counter() - time_before, module_path, success))
@ -2346,7 +2355,9 @@ async def init_builtin_extra_nodes():
         "nodes_easycache.py",
         "nodes_audio_encoder.py",
         "nodes_rope.py",
+        "nodes_logic.py",
         "nodes_nop.py",
+        "nodes_kandinsky5.py",
     ]

     import_failed = []
pyproject.toml (2 changes)
@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.75"
+version = "0.3.76"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
requirements.txt (2 changes)
@ -1,4 +1,4 @@
-comfyui-frontend-package==1.32.9
+comfyui-frontend-package==1.33.10
 comfyui-workflow-templates==0.7.25
 comfyui-embedded-docs==0.3.1
 torch
server.py (10 changes)
@ -44,6 +44,9 @@ from protocol import BinaryEventTypes
 # Import cache control middleware
 from middleware.cache_middleware import cache_control

+if args.enable_manager:
+    import comfyui_manager
+
 async def send_socket_catch_exception(function, message):
     try:
         await function(message)
@ -95,7 +98,7 @@ def create_cors_middleware(allowed_origin: str):
         response = await handler(request)

         response.headers['Access-Control-Allow-Origin'] = allowed_origin
-        response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
+        response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH'
         response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
         response.headers['Access-Control-Allow-Credentials'] = 'true'
         return response
@ -212,6 +215,9 @@ class PromptServer():
         if args.disable_api_nodes:
             middlewares.append(create_block_external_middleware())

+        if args.enable_manager:
+            middlewares.append(comfyui_manager.create_middleware())
+
         max_upload_size = round(args.max_upload_size * 1024 * 1024)
         self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
         self.sockets = dict()
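For readers unfamiliar with aiohttp middlewares: whatever `comfyui_manager.create_middleware()` returns must be an aiohttp middleware, i.e. a decorated coroutine the application invokes around every request handler. A runnable sketch of how such a middleware slots into the list above, with the manager middleware stubbed (only the `create_middleware` name comes from the diff; the stub and flag are illustrative):

```python
from aiohttp import web

@web.middleware
async def manager_middleware(request, handler):
    # A real manager middleware could intercept its own routes here;
    # this stub simply forwards every request to the next handler.
    return await handler(request)

middlewares = []
enable_manager = True  # stands in for args.enable_manager
if enable_manager:
    middlewares.append(manager_middleware)

app = web.Application(middlewares=middlewares)
```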
@ -599,7 +605,7 @@ class PromptServer():

             system_stats = {
                 "system": {
-                    "os": os.name,
+                    "os": sys.platform,
                     "ram_total": ram_total,
                     "ram_free": ram_free,
                     "comfyui_version": __version__,
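The `system_stats` change swaps `os.name`, which only distinguishes broad families, for `sys.platform`, which names the actual platform, so the reported field becomes far more informative:

```python
import os
import sys

# On Linux this prints:   posix linux
# On Windows:             nt win32
# On macOS:               posix darwin
print(os.name, sys.platform)
```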
@ -2,6 +2,7 @@ import unittest
 import torch
 import sys
 import os
+import json

 # Add comfy to path
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@ -15,6 +16,7 @@ if not has_gpu():

 from comfy import ops
 from comfy.quant_ops import QuantizedTensor
+import comfy.utils


 class SimpleModel(torch.nn.Module):
@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
         }

+        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
         # Create model and load state dict (strict=False because custom loading pops keys)
-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         model.load_state_dict(state_dict, strict=False)

         # Verify weights are wrapped in QuantizedTensor
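Every test in this file migrates to the same shape: the per-layer quantization config no longer parameterizes `mixed_precision_ops` directly but travels as JSON metadata that `convert_old_quants` folds into the state dict. A condensed sketch of the call pattern, reusing the names from the diff (it assumes the surrounding test's `state_dict`, `layer_quant_config`, `SimpleModel`, and imports, so it is not standalone):

```python
# Old shape: config passed straight to the ops factory.
# model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))

# New shape: config serialized into state-dict metadata first,
# and the ops factory constructed empty.
metadata = {"_quantization_metadata": json.dumps({"layers": layer_quant_config})}
state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata=metadata)
model = SimpleModel(operations=ops.mixed_precision_ops({}))
model.load_state_dict(state_dict, strict=False)
```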
@ -115,6 +118,7 @@ class TestMixedPrecisionOps(unittest.TestCase):

         # Forward pass
         input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
+        with torch.inference_mode():
             output = model(input_tensor)

         self.assertEqual(output.shape, (5, 40))
@ -141,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         model.load_state_dict(state_dict1, strict=False)

         # Save state dict
@ -178,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         model.load_state_dict(state_dict, strict=False)

         # Add a weight function (simulating LoRA)
@ -215,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

+        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+
         # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         with self.assertRaises(KeyError):
             model.load_state_dict(state_dict, strict=False)