Merge branch 'comfyanonymous:master' into master

Bahadir Ciloglu 2025-11-01 14:36:58 +03:00 committed by GitHub
commit b35cce6674
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 2567 additions and 1835 deletions

View File

@@ -1,2 +1,3 @@
 ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause

View File

@@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause

View File

@@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause

View File

@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold, the cache removes large items to free RAM. Default 4GB.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -144,6 +145,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
+    PinnedMem = "pinned_memory"

 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

View File

@@ -310,11 +310,13 @@ class ControlLoraOps:
             self.bias = None

         def forward(self, input):
-            weight, bias = comfy.ops.cast_bias_weight(self, input)
+            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
             if self.up is not None:
-                return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
+                x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
             else:
-                return torch.nn.functional.linear(input, weight, bias)
+                x = torch.nn.functional.linear(input, weight, bias)
+            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

     class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(
@@ -350,12 +352,13 @@
         def forward(self, input):
-            weight, bias = comfy.ops.cast_bias_weight(self, input)
+            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
             if self.up is not None:
-                return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
+                x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
             else:
-                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+                x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

 class ControlLora(ControlNet):
     def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

View File

@@ -522,7 +522,7 @@ class NextDiT(nn.Module):
         max_cap_len = max(l_effective_cap_len)
         max_img_len = max(l_effective_img_len)

-        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
+        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)

         for i in range(bsz):
             cap_len = l_effective_cap_len[i]
@@ -531,10 +531,22 @@
             H_tokens, W_tokens = H // pH, W // pW
             assert H_tokens * W_tokens == img_len

-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
+            rope_options = transformer_options.get("rope_options", None)
+            h_scale = 1.0
+            w_scale = 1.0
+            h_start = 0
+            w_start = 0
+            if rope_options is not None:
+                h_scale = rope_options.get("scale_y", 1.0)
+                w_scale = rope_options.get("scale_x", 1.0)
+                h_start = rope_options.get("shift_y", 0.0)
+                w_start = rope_options.get("shift_x", 0.0)
+
+            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
             position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
-            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
-            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+            row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
+            col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
             position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
             position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

View File

@@ -588,7 +588,7 @@ class WanModel(torch.nn.Module):
         x = self.unpatchify(x, grid_sizes)
         return x

-    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
+    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
         h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -601,10 +601,22 @@
         if steps_w is None:
             steps_w = w_len

+        h_start = 0
+        w_start = 0
+        rope_options = transformer_options.get("rope_options", None)
+        if rope_options is not None:
+            t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
+            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+            t_start += rope_options.get("shift_t", 0.0)
+            h_start += rope_options.get("shift_y", 0.0)
+            w_start += rope_options.get("shift_x", 0.0)
+
         img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
-        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
-        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)

         img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
         freqs = self.rope_embedder(img_ids).movedim(1, 2)
@@ -630,7 +642,7 @@
         if self.ref_conv is not None and "reference_latent" in kwargs:
             t_len += 1

-        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
+        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):

View File

@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -333,6 +333,14 @@
         if self.model_config.scaled_fp8 is not None:
             unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)

+        # Save mixed precision metadata
+        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
+            metadata = {
+                "format_version": "1.0",
+                "layers": self.model_config.layer_quant_config
+            }
+            unet_state_dict["_quantization_metadata"] = metadata
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

         if self.model_type == ModelType.V_PREDICTION:

View File

@@ -6,6 +6,20 @@ import math
 import logging
 import torch

+def detect_layer_quantization(metadata):
+    quant_key = "_quantization_metadata"
+    if metadata is not None and quant_key in metadata:
+        quant_metadata = metadata.pop(quant_key)
+        quant_metadata = json.loads(quant_metadata)
+        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
+            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
+            return quant_metadata["layers"]
+        else:
+            raise ValueError("Invalid quantization metadata format")
+    return None
+
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
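To illustrate the format detect_layer_quantization() expects, here is a hedged sketch of a metadata payload; the layer names are made up, and the only keys the function actually requires are "layers" and, optionally, "format_version":

import json

# Hypothetical metadata as it might arrive from a safetensors header.
metadata = {
    "_quantization_metadata": json.dumps({
        "format_version": "1.0",
        "layers": {
            "diffusion_model.blocks.0.attn.qkv": {"format": "float8_e4m3fn"},
            "diffusion_model.blocks.0.mlp.fc1": {"format": "float8_e4m3fn"},
        },
    })
}

layers = detect_layer_quantization(metadata)  # returns the "layers" dict
print(len(layers))  # 2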
@@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         else:
             model_config.optimizations["fp8"] = True

+    # Detect per-layer quantization (mixed precision)
+    layer_quant_config = detect_layer_quantization(metadata)
+    if layer_quant_config:
+        model_config.layer_quant_config = layer_quant_config
+        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+
     return model_config

 def unet_prefix_from_state_dict(state_dict):

View File

@@ -1013,6 +1013,16 @@ if args.async_offload:
        NUM_STREAMS = 2
    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))

+def current_stream(device):
+    if device is None:
+        return None
+    if is_device_cuda(device):
+        return torch.cuda.current_stream()
+    elif is_device_xpu(device):
+        return torch.xpu.current_stream()
+    else:
+        return None
+
 stream_counters = {}
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1021,21 +1031,17 @@ def get_offload_stream(device):
     if device in STREAMS:
         ss = STREAMS[device]
-        s = ss[stream_counter]
+        #Sync the oldest stream in the queue with the current
+        ss[stream_counter].wait_stream(current_stream(device))
         stream_counter = (stream_counter + 1) % len(ss)
-        if is_device_cuda(device):
-            ss[stream_counter].wait_stream(torch.cuda.current_stream())
-        elif is_device_xpu(device):
-            ss[stream_counter].wait_stream(torch.xpu.current_stream())
         stream_counters[device] = stream_counter
-        return s
+        return ss[stream_counter]
     elif is_device_cuda(device):
         ss = []
         for k in range(NUM_STREAMS):
             ss.append(torch.cuda.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
-        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     elif is_device_xpu(device):
@@ -1044,18 +1050,14 @@ def get_offload_stream(device):
             ss.append(torch.xpu.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
-        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     return None

 def sync_stream(device, stream):
-    if stream is None:
+    if stream is None or current_stream(device) is None:
         return
-    if is_device_cuda(device):
-        torch.cuda.current_stream().wait_stream(stream)
-    elif is_device_xpu(device):
-        torch.xpu.current_stream().wait_stream(stream)
+    current_stream(device).wait_stream(stream)

 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
@@ -1080,6 +1082,36 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)

+def pin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+    if not is_nvidia():
+        return False
+    if not is_device_cpu(tensor.device):
+        return False
+    if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
+        return True
+    return False
+
+def unpin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+    if not is_nvidia():
+        return False
+    if not is_device_cpu(tensor.device):
+        return False
+    if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
+        return True
+    return False
+
 def sage_attention_enabled():
     return args.use_sage_attention
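A hedged sketch of what pinning buys: once cudaHostRegister succeeds on a CPU tensor's storage, host-to-device copies with non_blocking=True can run truly asynchronously instead of staging through pageable memory. Assuming an Nvidia GPU with --fast pinned_memory enabled and the two helpers above:

import torch

t = torch.randn(1024, 1024)  # pageable CPU tensor
if pin_memory(t):  # cudaHostRegister on t's storage
    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        # The copy can now overlap with work on other streams.
        t_gpu = t.to("cuda", non_blocking=True)
    torch.cuda.current_stream().wait_stream(s)
    unpin_memory(t)  # cudaHostUnregister when the tensor is retired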

View File

@@ -238,6 +238,7 @@ class ModelPatcher:
         self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
+        self.pinned = set()

         self.attachments: dict[str] = {}
         self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -275,6 +276,9 @@
             self.size = comfy.model_management.module_size(self.model)
         return self.size

+    def get_ram_usage(self):
+        return self.model_size()
+
     def loaded_size(self):
         return self.model.model_loaded_weight_memory
@@ -450,6 +454,19 @@
     def set_model_post_input_patch(self, patch):
         self.set_model_patch(patch, "post_input")

+    def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
+        rope_options = self.model_options["transformer_options"].get("rope_options", {})
+        rope_options["scale_x"] = scale_x
+        rope_options["scale_y"] = scale_y
+        rope_options["scale_t"] = scale_t
+        rope_options["shift_x"] = shift_x
+        rope_options["shift_y"] = shift_y
+        rope_options["shift_t"] = shift_t
+        self.model_options["transformer_options"]["rope_options"] = rope_options
+
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj
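For illustration, a node-style sketch of how a caller would use this: clone the patcher and set the scales/shifts, and the values flow through model_options["transformer_options"]["rope_options"] into the Wan and Lumina rope code above. The 2x spatial scale here is an arbitrary example, not a recommended setting:

def apply_rope_scaling(model):
    # model: a comfy.model_patcher.ModelPatcher
    m = model.clone()
    m.set_model_rope_options(
        scale_x=2.0, shift_x=0.0,  # stretch width positions
        scale_y=2.0, shift_y=0.0,  # stretch height positions
        scale_t=1.0, shift_t=0.0,  # leave the temporal axis alone
    )
    return m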
@@ -618,6 +635,21 @@
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

+    def pin_weight_to_device(self, key):
+        weight, set_func, convert_func = get_key_weight(self.model, key)
+        if comfy.model_management.pin_memory(weight):
+            self.pinned.add(key)
+
+    def unpin_weight(self, key):
+        if key in self.pinned:
+            weight, set_func, convert_func = get_key_weight(self.model, key)
+            comfy.model_management.unpin_memory(weight)
+            self.pinned.remove(key)
+
+    def unpin_all_weights(self):
+        for key in list(self.pinned):
+            self.unpin_weight(key)
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
@@ -639,9 +671,11 @@
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        lowvram_mem_counter = 0
         loading = self._load_list()

         load_completely = []
+        offloaded = []
         loading.sort(reverse=True)
         for x in loading:
             n = x[1]
@@ -658,6 +692,7 @@
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
+                    lowvram_mem_counter += module_mem
                     if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                         continue
@@ -683,6 +718,7 @@
                             patch_counter += 1

                     cast_weight = True
+                    offloaded.append((module_mem, n, m, params))
                 else:
                     if hasattr(m, "comfy_cast_weights"):
                         wipe_lowvram_weight(m)
@@ -713,7 +749,9 @@
                     continue

             for param in params:
-                self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
+                key = "{}.{}".format(n, param)
+                self.unpin_weight(key)
+                self.patch_weight_to_device(key, device_to=device_to)

             logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
             m.comfy_patched_weights = True
@@ -721,11 +759,17 @@
         for x in load_completely:
             x[2].to(device_to)

+        for x in offloaded:
+            n = x[1]
+            params = x[3]
+            for param in params:
+                self.pin_weight_to_device("{}.{}".format(n, param))
+
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self.model.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
@@ -762,6 +806,7 @@
         self.eject_model()
         if unpatch_weights:
             self.unpatch_hooks()
+            self.unpin_all_weights()
             if self.model.model_lowvram:
                 for m in self.model.modules():
                     move_weight_functions(m, device_to)
@@ -857,6 +902,9 @@
                 memory_freed += module_mem
                 logging.debug("freed {}".format(n))

+                for param in params:
+                    self.pin_weight_to_device("{}.{}".format(n, param))
+
         self.model.model_lowvram = True
         self.model.lowvram_patch_counter += patch_counter
         self.model.model_loaded_weight_memory -= memory_freed
@@ -1259,5 +1307,6 @@
         self.clear_cached_hook_weights()

     def __del__(self):
+        self.unpin_all_weights()
         self.detach(unpatch_all=False)

View File

@@ -70,8 +70,12 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

 @torch.compiler.disable()
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+    # NOTE: offloadable=False is legacy behavior. If you are a custom node author reading this, please pass
+    # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
+    # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
             dtype = input.dtype
@@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device

-    offload_stream = comfy.model_management.get_offload_stream(device)
+    if offloadable:
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
     if offload_stream is not None:
         wf_context = offload_stream
     else:
@@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
            weight = f(weight)

     comfy.model_management.sync_stream(device, offload_stream)
-    return weight, bias
+    if offloadable:
+        return weight, bias, offload_stream
+    else:
+        #Legacy function signature
+        return weight, bias
+
+def uncast_bias_weight(s, weight, bias, offload_stream):
+    if offload_stream is None:
+        return
+    if weight is not None:
+        device = weight.device
+    else:
+        if bias is None:
+            return
+        device = bias.device
+    offload_stream.wait_stream(comfy.model_management.current_stream(device))

 class CastWeightBiasOp:
     comfy_cast_weights = False
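Per the note above, a minimal sketch of the pattern custom node authors are asked to adopt; the class here is hypothetical, but the call sequence matches the ops classes below:

class MyLinear(torch.nn.Linear, CastWeightBiasOp):
    def forward_comfy_cast_weights(self, input):
        # offloadable=True returns a third value: the stream the cast
        # was issued on (or None when offload is inactive).
        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
        x = torch.nn.functional.linear(input, weight, bias)
        # Tell the offload stream we are done with weight/bias.
        uncast_bias_weight(self, weight, bias, offload_stream)
        return x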
@@ -118,8 +143,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -133,8 +160,10 @@
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -148,8 +177,10 @@
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -172,8 +203,10 @@
                 return super()._conv_forward(input, weight, bias, *args, **kwargs)

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -187,8 +220,10 @@
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -203,11 +238,14 @@
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
                 bias = None
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+                offload_stream = None
+            x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -223,11 +261,15 @@
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
-            return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+                bias = None
+                offload_stream = None
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -246,10 +288,12 @@
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose2d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -268,10 +312,12 @@
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose1d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -289,8 +335,11 @@
                 output_dtype = out_dtype
             if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                 out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
+            x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -344,20 +393,18 @@ class manual_cast(disable_weight_init):

 def fp8_linear(self, input):
+    """
+    Legacy FP8 linear function for backward compatibility.
+    Uses QuantizedTensor subclass for dispatch.
+    """
     dtype = self.weight.dtype
     if dtype not in [torch.float8_e4m3fn]:
         return None

-    tensor_2d = False
-    if len(input.shape) == 2:
-        tensor_2d = True
-        input = input.unsqueeze(1)
-
-    input_shape = input.shape
     input_dtype = input.dtype
-    if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
-        w = w.t()
+    if input.ndim == 3 or input.ndim == 2:
+        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)

         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -369,23 +416,20 @@
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
             input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
+            layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
         else:
             scale_input = scale_input.to(input.device)
-            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
+            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)

-        if bias is not None:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
-        else:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
-
-        if isinstance(o, tuple):
-            o = o[0]
-
-        if tensor_2d:
-            return o.reshape(input_shape[0], -1)
-        return o.reshape((-1, input_shape[1], self.weight.shape[0]))
+        # Wrap weight in QuantizedTensor - this enables unified dispatch
+        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
+        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
+        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
+        uncast_bias_weight(self, w, bias, offload_stream)
+        return o

     return None
@@ -405,8 +449,10 @@ class fp8_ops(manual_cast):
         except Exception as e:
             logging.info("Exception during fp8 op: {}".format(e))

-        weight, bias = cast_bias_weight(self, input)
-        return torch.nn.functional.linear(input, weight, bias)
+        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+        x = torch.nn.functional.linear(input, weight, bias)
+        uncast_bias_weight(self, weight, bias, offload_stream)
+        return x

 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -434,12 +480,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
             if out is not None:
                 return out

-            weight, bias = cast_bias_weight(self, input)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)

             if weight.numel() < input.numel(): #TODO: optimize
-                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
             else:
-                return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+                x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def convert_weight(self, weight, inplace=False, **kwargs):
             if inplace:
@@ -478,7 +526,130 @@ if CUBLAS_IS_AVAILABLE:
         def forward(self, *args, **kwargs):
             return super().forward(*args, **kwargs)

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
+# ==============================================================================
+# Mixed Precision Operations
+# ==============================================================================
+
+from .quant_ops import QuantizedTensor
+
+QUANT_FORMAT_MIXINS = {
+    "float8_e4m3fn": {
+        "dtype": torch.float8_e4m3fn,
+        "layout_type": "TensorCoreFP8Layout",
+        "parameters": {
+            "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+            "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+        }
+    }
+}
+
+class MixedPrecisionOps(disable_weight_init):
+    _layer_quant_config = {}
+    _compute_dtype = torch.bfloat16
+
+    class Linear(torch.nn.Module, CastWeightBiasOp):
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None,
+        ) -> None:
+            super().__init__()
+            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+            # self.factory_kwargs = {"device": device, "dtype": dtype}
+            self.in_features = in_features
+            self.out_features = out_features
+
+            if bias:
+                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+            else:
+                self.register_parameter("bias", None)
+            self.tensor_class = None
+
+        def reset_parameters(self):
+            return None
+
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                  strict, missing_keys, unexpected_keys, error_msgs):
+            device = self.factory_kwargs["device"]
+            layer_name = prefix.rstrip('.')
+
+            weight_key = f"{prefix}weight"
+            weight = state_dict.pop(weight_key, None)
+            if weight is None:
+                raise ValueError(f"Missing weight for layer {layer_name}")
+            manually_loaded_keys = [weight_key]
+
+            if layer_name not in MixedPrecisionOps._layer_quant_config:
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+            else:
+                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                if quant_format is None:
+                    raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                mixin = QUANT_FORMAT_MIXINS[quant_format]
+                self.layout_type = mixin["layout_type"]
+
+                scale_key = f"{prefix}weight_scale"
+                layout_params = {
+                    'scale': state_dict.pop(scale_key, None),
+                    'orig_dtype': MixedPrecisionOps._compute_dtype
+                }
+                if layout_params['scale'] is not None:
+                    manually_loaded_keys.append(scale_key)
+
+                self.weight = torch.nn.Parameter(
+                    QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                    requires_grad=False
+                )
+
+                for param_name, param_value in mixin["parameters"].items():
+                    param_key = f"{prefix}{param_name}"
+                    _v = state_dict.pop(param_key, None)
+                    if _v is None:
+                        continue
+                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    manually_loaded_keys.append(param_key)
+
+            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            for key in manually_loaded_keys:
+                if key in missing_keys:
+                    missing_keys.remove(key)
+
+        def _forward(self, input, weight, bias):
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
+
+        def forward(self, input, *args, **kwargs):
+            run_every_op()
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(input, *args, **kwargs)
+
+            if (getattr(self, 'layout_type', None) is not None and
+                    getattr(self, 'input_scale', None) is not None and
+                    not isinstance(input, QuantizedTensor)):
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+            return self._forward(input, self.weight, self.bias)
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
+        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
+        MixedPrecisionOps._compute_dtype = compute_dtype
+        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
+        return MixedPrecisionOps
+
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
         return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

comfy/quant_ops.py (new file, 455 lines)

View File

@ -0,0 +1,455 @@
import torch
import logging
from typing import Tuple, Dict
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata.contiguous()
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type.__name__
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_dtype is not None and target_dtype != qt.dtype:
logging.warning(
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
f"but not supported for quantized tensors. Ignoring dtype."
)
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata)
qt_dest._layout_type = src._layout_type
qt_dest._layout_params = _copy_layout_params(src._layout_params)
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
orig_dtype = tensor.dtype
if scale is None:
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
# TODO: uncomment this if it's actually needed, because the clamp has a small performance penalty
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return qdata, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
return plain_tensor * scale
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
    # Fallback: dequantize any quantized operands and run a standard linear
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)
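For orientation, here is a minimal round-trip sketch of the layout defined above: quantize a tensor into TensorCoreFP8Layout, wrap it the same way fp8_linear does, and dequantize it back. The shape and tolerance are illustrative assumptions, not part of the patch.

import torch

# Round-trip check for TensorCoreFP8Layout (shape/tolerance are illustrative)
weight = torch.randn(128, 64, dtype=torch.bfloat16)

# quantize() returns the raw fp8 payload plus the layout params (scale, orig_dtype)
qdata, params = TensorCoreFP8Layout.quantize(weight, dtype=torch.float8_e4m3fn)

# Wrapping mirrors the constructor usage inside fp8_linear above
qweight = QuantizedTensor(qdata, "TensorCoreFP8Layout", params)

# dequantize() casts back to orig_dtype and re-applies the scale
restored = TensorCoreFP8Layout.dequantize(qdata, **params)
print(torch.allclose(weight.float(), restored.float(), atol=0.1))  # coarse fp8 tolerance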

View File

@ -143,6 +143,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model) return self.patcher.add_patches(patches, strength_patch, strength_model)
@ -293,6 +296,7 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False self.disable_offload = False
self.not_video = False self.not_video = False
self.size = None
self.downscale_index_formula = None self.downscale_index_formula = None
self.upscale_index_formula = None self.upscale_index_formula = None
@ -595,6 +599,16 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self): def throw_exception_if_invalid(self):
if self.first_stage_model is None: if self.first_stage_model is None:
@ -1262,7 +1276,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
""" """
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@ -1296,7 +1310,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd) weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device() load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "") model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None: if model_config is not None:
new_sd = sd new_sd = sd
@ -1330,7 +1344,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
else: else:
unet_dtype = dtype unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) if model_config.layer_quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False): if model_options.get("fp8_optimizations", False):
@ -1346,8 +1363,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model(unet_path, model_options={}): def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path) sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options) model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
if model is None: if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

View File

@ -50,6 +50,7 @@ class BASE:
manual_cast_dtype = None manual_cast_dtype = None
custom_operations = None custom_operations = None
scaled_fp8 = None scaled_fp8 = None
layer_quant_config = None # Per-layer quantization configuration for mixed precision
optimizations = {"fp8": False} optimizations = {"fp8": False}
@classmethod @classmethod

View File

@ -1,25 +1,13 @@
from __future__ import annotations from __future__ import annotations
import aiohttp import aiohttp
import mimetypes import mimetypes
from typing import Optional, Union from typing import Union
from comfy.utils import common_upscale
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
UploadRequest,
UploadResponse,
)
from server import PromptServer from server import PromptServer
from comfy.cli_args import args
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import torch import torch
import math
import base64 import base64
from .util import tensor_to_bytesio, bytesio_to_image_tensor
from io import BytesIO from io import BytesIO
@ -69,90 +57,6 @@ async def validate_and_cast_response(
return torch.stack(image_tensors, dim=0) return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response_content))
def text_filepath_to_base64_string(filepath: str) -> str: def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string.""" """Converts a text file to a base64 string."""
with open(filepath, "rb") as f: with open(filepath, "rb") as f:
@ -167,95 +71,3 @@ def text_filepath_to_data_uri(filepath: str) -> str:
if mime_type is None: if mime_type is None:
mime_type = "application/octet-stream" mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}" return f"data:{mime_type};base64,{base64_string}"
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
Args:
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
"""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask

View File

@ -0,0 +1,120 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class MinimaxBaseResponse(BaseModel):
status_code: int = Field(
...,
description='Status code. 0 indicates success, other values indicate errors.',
)
status_msg: str = Field(
..., description='Specific error details or success message.'
)
class File(BaseModel):
bytes: Optional[int] = Field(None, description='File size in bytes')
created_at: Optional[int] = Field(
None, description='Unix timestamp when the file was created, in seconds'
)
download_url: Optional[str] = Field(
None, description='The URL to download the video'
)
backup_download_url: Optional[str] = Field(
None, description='The backup URL to download the video'
)
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
filename: Optional[str] = Field(None, description='The name of the file')
purpose: Optional[str] = Field(None, description='The purpose of using the file')
class MinimaxFileRetrieveResponse(BaseModel):
base_resp: MinimaxBaseResponse
file: File
class MiniMaxModel(str, Enum):
T2V_01_Director = 'T2V-01-Director'
I2V_01_Director = 'I2V-01-Director'
S2V_01 = 'S2V-01'
I2V_01 = 'I2V-01'
I2V_01_live = 'I2V-01-live'
T2V_01 = 'T2V-01'
Hailuo_02 = 'MiniMax-Hailuo-02'
class Status6(str, Enum):
Queueing = 'Queueing'
Preparing = 'Preparing'
Processing = 'Processing'
Success = 'Success'
Fail = 'Fail'
class MinimaxTaskResultResponse(BaseModel):
base_resp: MinimaxBaseResponse
file_id: Optional[str] = Field(
None,
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
)
status: Status6 = Field(
...,
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
)
task_id: str = Field(..., description='The task ID being queried.')
class SubjectReferenceItem(BaseModel):
image: Optional[str] = Field(
None, description='URL or base64 encoding of the subject reference image.'
)
mask: Optional[str] = Field(
None,
description='URL or base64 encoding of the mask for the subject reference image.',
)
class MinimaxVideoGenerationRequest(BaseModel):
callback_url: Optional[str] = Field(
None,
description='Optional. URL to receive real-time status updates about the video generation task.',
)
first_frame_image: Optional[str] = Field(
None,
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
)
model: MiniMaxModel = Field(
...,
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
)
prompt: Optional[str] = Field(
None,
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
max_length=2000,
)
prompt_optimizer: Optional[bool] = Field(
True,
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
)
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
None,
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
)
duration: Optional[int] = Field(
None,
description="The length of the output video in seconds."
)
resolution: Optional[str] = Field(
None,
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
)
class MinimaxVideoGenerationResponse(BaseModel):
base_resp: MinimaxBaseResponse
task_id: str = Field(
..., description='The task ID for the asynchronous video generation task.'
)
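As a reading aid, a sketch of constructing one of these request models; the prompt and model choice are arbitrary examples, and the declared Field constraints (such as the 2000-character prompt cap) are enforced by pydantic:

# Illustrative construction of a text-to-video request using the models above
request = MinimaxVideoGenerationRequest(
    model=MiniMaxModel.T2V_01_Director,
    prompt="A slow dolly-in on a lighthouse at dusk [Push in]",
    prompt_optimizer=True,
)

# model_dump assumes pydantic v2; on v1 the equivalent is request.dict(exclude_none=True)
payload = request.model_dump(exclude_none=True)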

View File

@ -5,10 +5,6 @@ import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
validate_aspect_ratio,
)
from comfy_api_nodes.apis.bfl_api import ( from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest, BFLFluxExpandImageRequest,
BFLFluxFillImageRequest, BFLFluxFillImageRequest,
@ -23,8 +19,10 @@ from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
download_url_to_image_tensor, download_url_to_image_tensor,
poll_op, poll_op,
resize_mask_to_image,
sync_op, sync_op,
tensor_to_base64_string, tensor_to_base64_string,
validate_aspect_ratio_string,
validate_string, validate_string,
) )
@ -43,11 +41,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
""" """
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
@ -112,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
@classmethod @classmethod
def validate_inputs(cls, aspect_ratio: str): def validate_inputs(cls, aspect_ratio: str):
try: validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
return True return True
@classmethod @classmethod
@ -145,13 +129,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt=prompt, prompt=prompt,
prompt_upsampling=prompt_upsampling, prompt_upsampling=prompt_upsampling,
seed=seed, seed=seed,
aspect_ratio=validate_aspect_ratio( aspect_ratio=aspect_ratio,
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
),
raw=raw, raw=raw,
image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
@ -180,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode):
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
""" """
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
@ -261,13 +234,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
seed=0, seed=0,
prompt_upsampling=False, prompt_upsampling=False,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
aspect_ratio = validate_aspect_ratio( validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
if input_image is None: if input_image is None:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
initial_response = await sync_op( initial_response = await sync_op(

View File

@ -17,7 +17,7 @@ from comfy_api_nodes.util import (
poll_op, poll_op,
sync_op, sync_op,
upload_images_to_comfyapi, upload_images_to_comfyapi,
validate_image_aspect_ratio_range, validate_image_aspect_ratio,
validate_image_dimensions, validate_image_dimensions,
validate_string, validate_string,
) )
@ -403,7 +403,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1: if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.") raise ValueError("Exactly one input image is required.")
validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) validate_image_aspect_ratio(image, (1, 3), (3, 1))
source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
payload = Image2ImageTaskCreationRequest( payload = Image2ImageTaskCreationRequest(
model=model, model=model,
@ -565,7 +565,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
reference_images_urls = [] reference_images_urls = []
if n_input_images: if n_input_images:
for i in image: for i in image:
validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) validate_image_aspect_ratio(i, (1, 3), (3, 1))
reference_images_urls = await upload_images_to_comfyapi( reference_images_urls = await upload_images_to_comfyapi(
cls, cls,
image, image,
@ -798,7 +798,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
prompt = ( prompt = (
@ -923,7 +923,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
for i in (first_frame, last_frame): for i in (first_frame, last_frame):
validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
cls, cls,
@ -1045,7 +1045,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"]) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
for image in images: for image in images:
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
prompt = ( prompt = (

View File

@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO from comfy_api.latest import IO, ComfyExtension
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import torch import torch
@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
IdeogramV3Request, IdeogramV3Request,
IdeogramV3EditRequest, IdeogramV3EditRequest,
) )
from comfy_api_nodes.util import (
from comfy_api_nodes.apis.client import (
ApiEndpoint, ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
bytesio_to_image_tensor, bytesio_to_image_tensor,
download_url_as_bytesio,
resize_mask_to_image, resize_mask_to_image,
sync_op,
) )
from server import PromptServer
V1_V1_RES_MAP = { V1_V1_RES_MAP = {
"Auto":"AUTO", "Auto":"AUTO",
@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
for image_url in image_urls: for image_url in image_urls:
# Using functions from apinode_utils.py to handle downloading and processing # Using functions from apinode_utils.py to handle downloading and processing
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
image_tensors.append(img_tensor) image_tensors.append(img_tensor)
@ -233,19 +227,6 @@ async def download_and_process_images(image_urls):
return stacked_tensors return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(IO.ComfyNode): class IdeogramV1(IO.ComfyNode):
@classmethod @classmethod
@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1" model = "V_1_TURBO" if turbo else "V_1"
auth = { response = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
} response_model=IdeogramGenerateResponse,
operation = SynchronousOperation( data=IdeogramGenerateRequest(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
image_request=ImageRequest( image_request=ImageRequest(
prompt=prompt, prompt=prompt,
model=model, model=model,
num_images=num_images, num_images=num_images,
seed=seed, seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None, aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=( magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
) )
), ),
auth_kwargs=auth, max_retries=1,
) )
response = await operation.execute()
if not response.data or len(response.data) == 0: if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response") raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url] image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls)) return IO.NodeOutput(await download_and_process_images(image_urls))
@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode):
else: else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
auth = { response = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
} response_model=IdeogramGenerateResponse,
operation = SynchronousOperation( data=IdeogramGenerateRequest(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
image_request=ImageRequest( image_request=ImageRequest(
prompt=prompt, prompt=prompt,
model=model, model=model,
@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode):
seed=seed, seed=seed,
aspect_ratio=final_aspect_ratio, aspect_ratio=final_aspect_ratio,
resolution=final_resolution, resolution=final_resolution,
magic_prompt_option=( magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
style_type=style_type if style_type != "NONE" else None, style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None, color_palette=color_palette if color_palette else None,
) )
), ),
auth_kwargs=auth, max_retries=1,
) )
response = await operation.execute()
if not response.data or len(response.data) == 0: if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response") raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url] image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls)) return IO.NodeOutput(await download_and_process_images(image_urls))
@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode):
character_image=None, character_image=None,
character_mask=None, character_mask=None,
): ):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT" rendering_speed = "DEFAULT"
@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode):
# Check if both image and mask are provided for editing mode # Check if both image and mask are provided for editing mode
if image is not None and mask is not None: if image is not None and mask is not None:
# Edit mode
path = "/proxy/ideogram/ideogram-v3/edit"
# Process image and mask # Process image and mask
input_tensor = image.squeeze().cpu() input_tensor = image.squeeze().cpu()
# Resize mask to match image dimension # Resize mask to match image dimension
@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode):
if character_mask_binary: if character_mask_binary:
files["character_mask_binary"] = character_mask_binary files["character_mask_binary"] = character_mask_binary
# Execute the operation for edit mode response = await sync_op(
operation = SynchronousOperation( cls,
endpoint=ApiEndpoint( ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
path=path, response_model=IdeogramGenerateResponse,
method=HttpMethod.POST, data=edit_request,
request_model=IdeogramV3EditRequest,
response_model=IdeogramGenerateResponse,
),
request=edit_request,
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth, max_retries=1,
) )
elif image is not None or mask is not None: elif image is not None or mask is not None:
# If only one of image or mask is provided, raise an error # If only one of image or mask is provided, raise an error
raise Exception("Ideogram V3 image editing requires both an image AND a mask") raise Exception("Ideogram V3 image editing requires both an image AND a mask")
else: else:
# Generation mode
path = "/proxy/ideogram/ideogram-v3/generate"
# Create generation request # Create generation request
gen_request = IdeogramV3Request( gen_request = IdeogramV3Request(
prompt=prompt, prompt=prompt,
@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode):
if files: if files:
gen_request.style_type = "AUTO" gen_request.style_type = "AUTO"
# Execute the operation for generation mode response = await sync_op(
operation = SynchronousOperation( cls,
endpoint=ApiEndpoint( endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
path=path, response_model=IdeogramGenerateResponse,
method=HttpMethod.POST, data=gen_request,
request_model=IdeogramV3Request,
response_model=IdeogramGenerateResponse,
),
request=gen_request,
files=files if files else None, files=files if files else None,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth, max_retries=1,
) )
# Execute the operation and process response
response = await operation.execute()
if not response.data or len(response.data) == 0: if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response") raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url] image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls)) return IO.NodeOutput(await download_and_process_images(image_urls))
@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension):
IdeogramV3, IdeogramV3,
] ]
async def comfy_entrypoint() -> IdeogramExtension: async def comfy_entrypoint() -> IdeogramExtension:
return IdeogramExtension() return IdeogramExtension()

View File

@ -282,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
""" """
validate_image_dimensions(image, min_width=300, min_height=300) validate_image_dimensions(image, min_width=300, min_height=300)
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5) validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))
def get_video_from_response(response) -> KlingVideoResult: def get_video_from_response(response) -> KlingVideoResult:

View File

@ -1,69 +1,51 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional from typing import Optional
import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.luma_api import ( from comfy_api_nodes.apis.luma_api import (
LumaImageModel,
LumaVideoModel,
LumaVideoOutputResolution,
LumaVideoModelOutputDuration,
LumaAspectRatio, LumaAspectRatio,
LumaState,
LumaImageGenerationRequest,
LumaGenerationRequest,
LumaGeneration,
LumaCharacterRef, LumaCharacterRef,
LumaModifyImageRef, LumaConceptChain,
LumaGeneration,
LumaGenerationRequest,
LumaImageGenerationRequest,
LumaImageIdentity, LumaImageIdentity,
LumaImageModel,
LumaImageReference,
LumaIO,
LumaKeyframes,
LumaModifyImageRef,
LumaReference, LumaReference,
LumaReferenceChain, LumaReferenceChain,
LumaImageReference, LumaVideoModel,
LumaKeyframes, LumaVideoModelOutputDuration,
LumaConceptChain, LumaVideoOutputResolution,
LumaIO,
get_luma_concepts, get_luma_concepts,
) )
from comfy_api_nodes.apis.client import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
HttpMethod, download_url_to_image_tensor,
SynchronousOperation, download_url_to_video_output,
PollingOperation, poll_op,
EmptyRequest, sync_op,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi, upload_images_to_comfyapi,
process_image_response, validate_string,
) )
from server import PromptServer
from comfy_api_nodes.util import validate_string
import aiohttp
import torch
from io import BytesIO
LUMA_T2V_AVERAGE_DURATION = 105 LUMA_T2V_AVERAGE_DURATION = 105
LUMA_I2V_AVERAGE_DURATION = 100 LUMA_I2V_AVERAGE_DURATION = 100
def image_result_url_extractor(response: LumaGeneration):
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
class LumaReferenceNode(IO.ComfyNode): class LumaReferenceNode(IO.ComfyNode):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="LumaReferenceNode", node_id="LumaReferenceNode",
display_name="Luma Reference", display_name="Luma Reference",
category="api node/image/Luma", category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""), description="Holds an image and weight for use with Luma Generate Image node.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
"image", "image",
@ -83,17 +65,10 @@ class LumaReferenceNode(IO.ComfyNode):
), ),
], ],
outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")], outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
) )
@classmethod @classmethod
def execute( def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
) -> IO.NodeOutput:
if luma_ref is not None: if luma_ref is not None:
luma_ref = luma_ref.clone() luma_ref = luma_ref.clone()
else: else:
@ -103,17 +78,13 @@ class LumaReferenceNode(IO.ComfyNode):
class LumaConceptsNode(IO.ComfyNode): class LumaConceptsNode(IO.ComfyNode):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="LumaConceptsNode", node_id="LumaConceptsNode",
display_name="Luma Concepts", display_name="Luma Concepts",
category="api node/video/Luma", category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""), description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"concept1", "concept1",
@ -138,11 +109,6 @@ class LumaConceptsNode(IO.ComfyNode):
), ),
], ],
outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")], outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
) )
@classmethod @classmethod
@ -161,17 +127,13 @@ class LumaConceptsNode(IO.ComfyNode):
class LumaImageGenerationNode(IO.ComfyNode): class LumaImageGenerationNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="LumaImageNode", node_id="LumaImageNode",
display_name="Luma Text to Image", display_name="Luma Text to Image",
category="api node/image/Luma", category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""), description="Generates images synchronously based on prompt and aspect ratio.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -237,45 +199,30 @@ class LumaImageGenerationNode(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
seed, seed,
style_image_weight: float, style_image_weight: float,
image_luma_ref: LumaReferenceChain = None, image_luma_ref: Optional[LumaReferenceChain] = None,
style_image: torch.Tensor = None, style_image: Optional[torch.Tensor] = None,
character_image: torch.Tensor = None, character_image: Optional[torch.Tensor] = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3) validate_string(prompt, strip_whitespace=True, min_length=3)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# handle image_luma_ref # handle image_luma_ref
api_image_ref = None api_image_ref = None
if image_luma_ref is not None: if image_luma_ref is not None:
api_image_ref = await cls._convert_luma_refs( api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
)
# handle style_luma_ref # handle style_luma_ref
api_style_ref = None api_style_ref = None
if style_image is not None: if style_image is not None:
api_style_ref = await cls._convert_style_image( api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
)
# handle character_ref images # handle character_ref images
character_ref = None character_ref = None
if character_image is not None: if character_image is not None:
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
character_image, max_images=4, auth_kwargs=auth_kwargs, character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
)
operation = SynchronousOperation( response_api = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/luma/generations/image", ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
method=HttpMethod.POST, response_model=LumaGeneration,
request_model=LumaImageGenerationRequest, data=LumaImageGenerationRequest(
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
prompt=prompt, prompt=prompt,
model=model, model=model,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
@ -283,41 +230,21 @@ class LumaImageGenerationNode(IO.ComfyNode):
style_ref=api_style_ref, style_ref=api_style_ref,
character_ref=character_ref, character_ref=character_ref,
), ),
auth_kwargs=auth_kwargs,
) )
response_api: LumaGeneration = await operation.execute() response_poll = await poll_op(
cls,
operation = PollingOperation( ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
poll_endpoint=ApiEndpoint( response_model=LumaGeneration,
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return IO.NodeOutput(img)
@classmethod @classmethod
async def _convert_luma_refs( async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
luma_urls = [] luma_urls = []
ref_count = 0 ref_count = 0
for ref in luma_ref.refs: for ref in luma_ref.refs:
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
luma_urls.append(download_urls[0]) luma_urls.append(download_urls[0])
ref_count += 1 ref_count += 1
if ref_count >= max_refs: if ref_count >= max_refs:
@ -325,27 +252,19 @@ class LumaImageGenerationNode(IO.ComfyNode):
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
@classmethod @classmethod
async def _convert_style_image( async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
): return await cls._convert_luma_refs(chain, max_refs=1)
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(IO.ComfyNode): class LumaImageModifyNode(IO.ComfyNode):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="LumaImageModifyNode", node_id="LumaImageModifyNode",
display_name="Luma Image to Image", display_name="Luma Image to Image",
category="api node/image/Luma", category="api node/image/Luma",
description=cleandoc(cls.__doc__ or ""), description="Modifies images synchronously based on prompt and aspect ratio.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
"image", "image",
@ -395,68 +314,37 @@ class LumaImageModifyNode(IO.ComfyNode):
image_weight: float, image_weight: float,
seed, seed,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
auth_kwargs = { download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# first, upload image
download_urls = await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=auth_kwargs,
)
image_url = download_urls[0] image_url = download_urls[0]
# next, make Luma call with download url provided response_api = await sync_op(
operation = SynchronousOperation( cls,
endpoint=ApiEndpoint( ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
path="/proxy/luma/generations/image", response_model=LumaGeneration,
method=HttpMethod.POST, data=LumaImageGenerationRequest(
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
prompt=prompt, prompt=prompt,
model=model, model=model,
modify_image_ref=LumaModifyImageRef( modify_image_ref=LumaModifyImageRef(
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2) url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
), ),
), ),
auth_kwargs=auth_kwargs,
) )
response_api: LumaGeneration = await operation.execute() response_poll = await poll_op(
cls,
operation = PollingOperation( ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
poll_endpoint=ApiEndpoint( response_model=LumaGeneration,
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=cls.hidden.unique_id,
auth_kwargs=auth_kwargs,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.image) as img_response:
img = process_image_response(await img_response.content.read())
return IO.NodeOutput(img)
class LumaTextToVideoGenerationNode(IO.ComfyNode): class LumaTextToVideoGenerationNode(IO.ComfyNode):
"""
Generates videos synchronously based on prompt and output_size.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="LumaVideoNode", node_id="LumaVideoNode",
display_name="Luma Text to Video", display_name="Luma Text to Video",
category="api node/video/Luma", category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""), description="Generates videos synchronously based on prompt and output_size.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -498,7 +386,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
"luma_concepts", "luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True, optional=True,
) ),
], ],
outputs=[IO.Video.Output()], outputs=[IO.Video.Output()],
hidden=[ hidden=[
@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
duration: str, duration: str,
loop: bool, loop: bool,
seed, seed,
luma_concepts: LumaConceptChain = None, luma_concepts: Optional[LumaConceptChain] = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3) validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None
auth_kwargs = { response_api = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/luma/generations", method="POST"),
} response_model=LumaGeneration,
operation = SynchronousOperation( data=LumaGenerationRequest(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
prompt=prompt, prompt=prompt,
model=model, model=model,
resolution=resolution, resolution=resolution,
@ -545,47 +426,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
loop=loop, loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None, concepts=luma_concepts.create_api_model() if luma_concepts else None,
), ),
auth_kwargs=auth_kwargs,
) )
response_api: LumaGeneration = await operation.execute() response_poll = await poll_op(
cls,
if cls.hidden.unique_id: ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) response_model=LumaGeneration,
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION, estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class LumaImageToVideoGenerationNode(IO.ComfyNode): class LumaImageToVideoGenerationNode(IO.ComfyNode):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="LumaImageToVideoNode", node_id="LumaImageToVideoNode",
display_name="Luma Image to Video", display_name="Luma Image to Video",
category="api node/video/Luma", category="api node/video/Luma",
description=cleandoc(cls.__doc__ or ""), description="Generates videos synchronously based on prompt, input images, and output_size.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -637,7 +496,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
"luma_concepts", "luma_concepts",
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
optional=True, optional=True,
) ),
], ],
outputs=[IO.Video.Output()], outputs=[IO.Video.Output()],
hidden=[ hidden=[
@ -662,25 +521,15 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
luma_concepts: LumaConceptChain = None, luma_concepts: LumaConceptChain = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if first_image is None and last_image is None: if first_image is None and last_image is None:
raise Exception( raise Exception("At least one of first_image and last_image requires an input.")
"At least one of first_image and last_image requires an input." keyframes = await cls._convert_to_keyframes(first_image, last_image)
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None
response_api = await sync_op(
operation = SynchronousOperation( cls,
endpoint=ApiEndpoint( ApiEndpoint(path="/proxy/luma/generations", method="POST"),
path="/proxy/luma/generations", response_model=LumaGeneration,
method=HttpMethod.POST, data=LumaGenerationRequest(
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
prompt=prompt, prompt=prompt,
model=model, model=model,
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
@ -690,54 +539,31 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
keyframes=keyframes, keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None, concepts=luma_concepts.create_api_model() if luma_concepts else None,
), ),
auth_kwargs=auth_kwargs,
) )
response_api: LumaGeneration = await operation.execute() response_poll = await poll_op(
cls,
if cls.hidden.unique_id: poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) response_model=LumaGeneration,
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state, status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=cls.hidden.unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION, estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=auth_kwargs,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.assets.video) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
@classmethod @classmethod
async def _convert_to_keyframes( async def _convert_to_keyframes(
cls, cls,
first_image: torch.Tensor = None, first_image: torch.Tensor = None,
last_image: torch.Tensor = None, last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
): ):
if first_image is None and last_image is None: if first_image is None and last_image is None:
return None return None
frame0 = None frame0 = None
frame1 = None frame1 = None
if first_image is not None: if first_image is not None:
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame0 = LumaImageReference(type="image", url=download_urls[0]) frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None: if last_image is not None:
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame1 = LumaImageReference(type="image", url=download_urls[0]) frame1 = LumaImageReference(type="image", url=download_urls[0])
return LumaKeyframes(frame0=frame0, frame1=frame1) return LumaKeyframes(frame0=frame0, frame1=frame1)
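
Every node converted in this change follows the same two-step shape: one sync_op call to enqueue the generation, then one poll_op call against the status endpoint until a terminal state is reached. A minimal sketch of that shape, assuming a hypothetical /proxy/example path and stand-in response models (ExampleRequest and ExampleGeneration are placeholders, not real models):

from comfy_api_nodes.util import ApiEndpoint, poll_op, sync_op

async def _example_generate(cls, prompt: str) -> str:
    # Enqueue the generation with a single POST, parsed into the response model.
    created = await sync_op(
        cls,
        ApiEndpoint(path="/proxy/example/generations", method="POST"),  # hypothetical path
        response_model=ExampleGeneration,    # placeholder pydantic model
        data=ExampleRequest(prompt=prompt),  # placeholder pydantic model
    )
    # Poll until the status extractor reports a terminal state; auth and
    # progress reporting are presumably derived from `cls` internally.
    done = await poll_op(
        cls,
        poll_endpoint=ApiEndpoint(path=f"/proxy/example/generations/{created.id}"),
        response_model=ExampleGeneration,
        status_extractor=lambda r: r.state,
    )
    return done.video_url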

View File

@ -1,71 +1,57 @@
from inspect import cleandoc
from typing import Optional from typing import Optional
import logging
import torch
import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis.minimax_api import (
MinimaxFileRetrieveResponse,
MiniMaxModel,
MinimaxTaskResultResponse,
MinimaxVideoGenerationRequest, MinimaxVideoGenerationRequest,
MinimaxVideoGenerationResponse, MinimaxVideoGenerationResponse,
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem, SubjectReferenceItem,
MiniMaxModel,
) )
from comfy_api_nodes.apis.client import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
HttpMethod, download_url_to_video_output,
SynchronousOperation, poll_op,
PollingOperation, sync_op,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
upload_images_to_comfyapi, upload_images_to_comfyapi,
validate_string,
) )
from comfy_api_nodes.util import validate_string
from server import PromptServer
I2V_AVERAGE_DURATION = 114 I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234 T2V_AVERAGE_DURATION = 234
async def _generate_mm_video( async def _generate_mm_video(
cls: type[IO.ComfyNode],
*, *,
auth: dict[str, str],
node_id: str,
prompt_text: str, prompt_text: str,
seed: int, seed: int,
model: str, model: str,
image: Optional[torch.Tensor] = None, # used for ImageToVideo image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
average_duration: Optional[int] = None, average_duration: Optional[int] = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if image is None: if image is None:
validate_string(prompt_text, field_name="prompt_text") validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None image_url = None
if image is not None: if image is not None:
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0] image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None subject_reference = None
if subject is not None: if subject is not None:
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0] subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)] subject_reference = [SubjectReferenceItem(image=subject_url)]
response = await sync_op(
video_generate_operation = SynchronousOperation( cls,
endpoint=ApiEndpoint( ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
path="/proxy/minimax/video_generation", response_model=MinimaxVideoGenerationResponse,
method=HttpMethod.POST, data=MinimaxVideoGenerationRequest(
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model), model=MiniMaxModel(model),
prompt=prompt_text, prompt=prompt_text,
callback_url=None, callback_url=None,
@ -73,81 +59,50 @@ async def _generate_mm_video(
subject_reference=subject_reference, subject_reference=subject_reference,
prompt_optimizer=None, prompt_optimizer=None,
), ),
auth_kwargs=auth,
) )
response = await video_generate_operation.execute()
task_id = response.task_id task_id = response.task_id
if not task_id: if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}") raise Exception(f"MiniMax generation failed: {response.base_resp}")
video_generate_operation = PollingOperation( task_result = await poll_op(
poll_endpoint=ApiEndpoint( cls,
path="/proxy/minimax/query/video_generation", ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
method=HttpMethod.GET, response_model=MinimaxTaskResultResponse,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value, status_extractor=lambda x: x.status.value,
estimated_duration=average_duration, estimated_duration=average_duration,
node_id=node_id,
auth_kwargs=auth,
) )
task_result = await video_generate_operation.execute()
file_id = task_result.file_id file_id = task_result.file_id
if file_id is None: if file_id is None:
raise Exception("Request was not successful. Missing file ID.") raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation( file_result = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/minimax/files/retrieve", ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
method=HttpMethod.GET, response_model=MinimaxFileRetrieveResponse,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
) )
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url file_url = file_result.file.download_url
if file_url is None: if file_url is None:
raise Exception( raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
f"No video was found in the response. Full response: {file_result.model_dump()}" if file_result.file.backup_download_url:
) try:
logging.info("Generated video URL: %s", file_url) return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
if node_id: except Exception: # if we have a second URL to retrieve the result, try again using that one
if hasattr(file_result.file, "backup_download_url"): return IO.NodeOutput(
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
else: )
message = f"Result URL: {file_url}" return IO.NodeOutput(await download_url_to_video_output(file_url))
PromptServer.instance.send_progress_text(message, node_id)
# Download and return as VideoFromFile
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return IO.NodeOutput(VideoFromFile(video_io))
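
When the MiniMax file result carries a backup_download_url, the new code fails fast on the primary URL (short timeout, two retries) and only then retries harder against the backup. The same fallback shape in isolation, as a sketch in which `fetch` stands in for download_url_to_video_output:

from typing import Optional

async def fetch_with_backup(primary_url: str, backup_url: Optional[str], fetch):
    if backup_url:
        try:
            # Fail fast on the primary URL: the backup gives a second chance anyway.
            return await fetch(primary_url, timeout=10, max_retries=2)
        except Exception:
            # Primary failed; retry harder against the backup.
            return await fetch(backup_url, max_retries=3)
    # No backup available: use the default, more patient retry policy.
    return await fetch(primary_url)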
class MinimaxTextToVideoNode(IO.ComfyNode): class MinimaxTextToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="MinimaxTextToVideoNode", node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video", display_name="MiniMax Text to Video",
category="api node/video/MiniMax", category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""), description="Generates videos synchronously based on a prompt, and optional parameters.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt_text", "prompt_text",
@ -189,11 +144,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
seed: int = 0, seed: int = 0,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
return await _generate_mm_video( return await _generate_mm_video(
auth={ cls,
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
prompt_text=prompt_text, prompt_text=prompt_text,
seed=seed, seed=seed,
model=model, model=model,
@ -204,17 +155,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
class MinimaxImageToVideoNode(IO.ComfyNode): class MinimaxImageToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="MinimaxImageToVideoNode", node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video", display_name="MiniMax Image to Video",
category="api node/video/MiniMax", category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""), description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
"image", "image",
@ -261,11 +208,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
seed: int = 0, seed: int = 0,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
return await _generate_mm_video( return await _generate_mm_video(
auth={ cls,
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
prompt_text=prompt_text, prompt_text=prompt_text,
seed=seed, seed=seed,
model=model, model=model,
@ -276,17 +219,13 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
class MinimaxSubjectToVideoNode(IO.ComfyNode): class MinimaxSubjectToVideoNode(IO.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="MinimaxSubjectToVideoNode", node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video", display_name="MiniMax Subject to Video",
category="api node/video/MiniMax", category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""), description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
"subject", "subject",
@ -333,11 +272,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
seed: int = 0, seed: int = 0,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
return await _generate_mm_video( return await _generate_mm_video(
auth={ cls,
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
prompt_text=prompt_text, prompt_text=prompt_text,
seed=seed, seed=seed,
model=model, model=model,
@ -348,15 +283,13 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
class MinimaxHailuoVideoNode(IO.ComfyNode): class MinimaxHailuoVideoNode(IO.ComfyNode):
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="MinimaxHailuoVideoNode", node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video", display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax", category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""), description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt_text", "prompt_text",
@ -420,10 +353,6 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
resolution: str = "768P", resolution: str = "768P",
model: str = "MiniMax-Hailuo-02", model: str = "MiniMax-Hailuo-02",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if first_frame_image is None: if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text") validate_string(prompt_text, field_name="prompt_text")
@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
# upload image, if passed in # upload image, if passed in
image_url = None image_url = None
if first_frame_image is not None: if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0] image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
video_generate_operation = SynchronousOperation( response = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/minimax/video_generation", ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
method=HttpMethod.POST, response_model=MinimaxVideoGenerationResponse,
request_model=MinimaxVideoGenerationRequest, data=MinimaxVideoGenerationRequest(
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model), model=MiniMaxModel(model),
prompt=prompt_text, prompt=prompt_text,
callback_url=None, callback_url=None,
@ -453,67 +379,42 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
duration=duration, duration=duration,
resolution=resolution, resolution=resolution,
), ),
auth_kwargs=auth,
) )
response = await video_generate_operation.execute()
task_id = response.task_id task_id = response.task_id
if not task_id: if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}") raise Exception(f"MiniMax generation failed: {response.base_resp}")
average_duration = 120 if resolution == "768P" else 240 average_duration = 120 if resolution == "768P" else 240
video_generate_operation = PollingOperation( task_result = await poll_op(
poll_endpoint=ApiEndpoint( cls,
path="/proxy/minimax/query/video_generation", ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
method=HttpMethod.GET, response_model=MinimaxTaskResultResponse,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value, status_extractor=lambda x: x.status.value,
estimated_duration=average_duration, estimated_duration=average_duration,
node_id=cls.hidden.unique_id,
auth_kwargs=auth,
) )
task_result = await video_generate_operation.execute()
file_id = task_result.file_id file_id = task_result.file_id
if file_id is None: if file_id is None:
raise Exception("Request was not successful. Missing file ID.") raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation( file_result = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/minimax/files/retrieve", ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
method=HttpMethod.GET, response_model=MinimaxFileRetrieveResponse,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=auth,
) )
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url file_url = file_result.file.download_url
if file_url is None: if file_url is None:
raise Exception( raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info("Generated video URL: %s", file_url)
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
video_io = await download_url_to_bytesio(file_url) if file_result.file.backup_download_url:
if video_io is None: try:
error_msg = f"Failed to download video from {file_url}" return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
logging.error(error_msg) except Exception: # if we have a second URL to retrieve the result, try again using that one
raise Exception(error_msg) return IO.NodeOutput(
return IO.NodeOutput(VideoFromFile(video_io)) await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
)
return IO.NodeOutput(await download_url_to_video_output(file_url))
class MinimaxExtension(ComfyExtension): class MinimaxExtension(ComfyExtension):

View File

@ -225,7 +225,7 @@ class OpenAIDalle2(ComfyNodeABC):
), ),
files=( files=(
{ {
"image": img_binary, "image": ("image.png", img_binary, "image/png"),
} }
if img_binary if img_binary
else None else None
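
A bare value in a multipart files mapping uploads a nameless part, while the three-tuple spells out the filename and MIME type, which stricter endpoints need in order to recognize the payload as a PNG. This is the same (filename, fileobj, content_type) convention the requests library uses; a standalone illustration with placeholder endpoint and payload:

import io
import requests

url = "https://example.com/upload"       # placeholder endpoint
img_binary = io.BytesIO(b"<png bytes>")  # placeholder payload

# Bare value: the part is sent without an explicit filename or content type.
requests.post(url, files={"image": img_binary})

# Tuple form: filename and MIME type are explicit.
requests.post(url, files={"image": ("image.png", img_binary, "image/png")})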

View File

@ -1,7 +1,6 @@
from inspect import cleandoc import torch
from typing import Optional
from typing_extensions import override from typing_extensions import override
from io import BytesIO from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.pixverse_api import ( from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest, PixverseTextVideoRequest,
PixverseImageVideoRequest, PixverseImageVideoRequest,
@ -17,53 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import (
PixverseIO, PixverseIO,
pixverse_templates, pixverse_templates,
) )
from comfy_api_nodes.apis.client import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
HttpMethod, download_url_to_video_output,
SynchronousOperation, poll_op,
PollingOperation, sync_op,
EmptyRequest, tensor_to_bytesio,
validate_string,
) )
from comfy_api_nodes.util import validate_string, tensor_to_bytesio
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
import torch
import aiohttp
AVERAGE_DURATION_T2V = 32 AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30 AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52 AVERAGE_DURATION_T2T = 52
def get_video_url_from_response( async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
response: PixverseGenerationStatusResponse, response_upload = await sync_op(
) -> Optional[str]: cls,
if response.Resp is None or response.Resp.url is None: ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
return None response_model=PixverseImageUploadResponse,
return str(response.Resp.url)
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=PixverseImageUploadResponse,
),
request=EmptyRequest(),
files={"image": tensor_to_bytesio(image)}, files={"image": tensor_to_bytesio(image)},
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
) )
response_upload: PixverseImageUploadResponse = await operation.execute()
if response_upload.Resp is None: if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
return response_upload.Resp.img_id return response_upload.Resp.img_id
@ -93,17 +69,13 @@ class PixverseTemplateNode(IO.ComfyNode):
class PixverseTextToVideoNode(IO.ComfyNode): class PixverseTextToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="PixverseTextToVideoNode", node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video", display_name="PixVerse Text to Video",
category="api node/video/PixVerse", category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""), description="Generates videos based on prompt and output_size.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -170,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
negative_prompt: str = None, negative_prompt: str = None,
pixverse_template: int = None, pixverse_template: int = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False, min_length=1)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p: if quality == PixverseQuality.res_1080p:
@ -179,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5: elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal motion_mode = PixverseMotionMode.normal
auth = { response_api = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
} response_model=PixverseVideoResponse,
operation = SynchronousOperation( data=PixverseTextVideoRequest(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
method=HttpMethod.POST,
request_model=PixverseTextVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTextVideoRequest(
prompt=prompt, prompt=prompt,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
quality=quality, quality=quality,
@ -200,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode):
template_id=pixverse_template, template_id=pixverse_template,
seed=seed, seed=seed,
), ),
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.Resp is None: if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation( response_poll = await poll_op(
poll_endpoint=ApiEndpoint( cls,
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
method=HttpMethod.GET, response_model=PixverseGenerationStatusResponse,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful], completed_statuses=[PixverseStatus.successful],
failed_statuses=[ failed_statuses=[
PixverseStatus.contents_moderation, PixverseStatus.contents_moderation,
@ -221,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode):
PixverseStatus.deleted, PixverseStatus.deleted,
], ],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V, estimated_duration=AVERAGE_DURATION_T2V,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseImageToVideoNode(IO.ComfyNode): class PixverseImageToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="PixverseImageToVideoNode", node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video", display_name="PixVerse Image to Video",
category="api node/video/PixVerse", category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""), description="Generates videos based on prompt and output_size.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.String.Input( IO.String.Input(
@ -309,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
pixverse_template: int = None, pixverse_template: int = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
auth = { img_id = await upload_image_to_pixverse(cls, image)
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -323,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5: elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation( response_api = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/pixverse/video/img/generate", ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
method=HttpMethod.POST, response_model=PixverseVideoResponse,
request_model=PixverseImageVideoRequest, data=PixverseImageVideoRequest(
response_model=PixverseVideoResponse,
),
request=PixverseImageVideoRequest(
img_id=img_id, img_id=img_id,
prompt=prompt, prompt=prompt,
quality=quality, quality=quality,
@ -340,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode):
template_id=pixverse_template, template_id=pixverse_template,
seed=seed, seed=seed,
), ),
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.Resp is None: if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation( response_poll = await poll_op(
poll_endpoint=ApiEndpoint( cls,
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
method=HttpMethod.GET, response_model=PixverseGenerationStatusResponse,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful], completed_statuses=[PixverseStatus.successful],
failed_statuses=[ failed_statuses=[
PixverseStatus.contents_moderation, PixverseStatus.contents_moderation,
@ -361,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode):
PixverseStatus.deleted, PixverseStatus.deleted,
], ],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V, estimated_duration=AVERAGE_DURATION_I2V,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseTransitionVideoNode(IO.ComfyNode): class PixverseTransitionVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="PixverseTransitionVideoNode", node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video", display_name="PixVerse Transition Video",
category="api node/video/PixVerse", category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""), description="Generates videos based on prompt and output_size.",
inputs=[ inputs=[
IO.Image.Input("first_frame"), IO.Image.Input("first_frame"),
IO.Image.Input("last_frame"), IO.Image.Input("last_frame"),
@ -445,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt: str = None, negative_prompt: str = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
auth = { first_frame_id = await upload_image_to_pixverse(cls, first_frame)
"auth_token": cls.hidden.auth_token_comfy_org, last_frame_id = await upload_image_to_pixverse(cls, last_frame)
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -460,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5: elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation( response_api = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/pixverse/video/transition/generate", ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
method=HttpMethod.POST, response_model=PixverseVideoResponse,
request_model=PixverseTransitionVideoRequest, data=PixverseTransitionVideoRequest(
response_model=PixverseVideoResponse,
),
request=PixverseTransitionVideoRequest(
first_frame_img=first_frame_id, first_frame_img=first_frame_id,
last_frame_img=last_frame_id, last_frame_img=last_frame_id,
prompt=prompt, prompt=prompt,
@ -477,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
seed=seed, seed=seed,
), ),
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.Resp is None: if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation( response_poll = await poll_op(
poll_endpoint=ApiEndpoint( cls,
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
method=HttpMethod.GET, response_model=PixverseGenerationStatusResponse,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful], completed_statuses=[PixverseStatus.successful],
failed_statuses=[ failed_statuses=[
PixverseStatus.contents_moderation, PixverseStatus.contents_moderation,
@ -498,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
PixverseStatus.deleted, PixverseStatus.deleted,
], ],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V, estimated_duration=AVERAGE_DURATION_T2V,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixVerseExtension(ComfyExtension): class PixVerseExtension(ComfyExtension):

File diff suppressed because it is too large

View File

@ -200,7 +200,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, min_length=1) validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
cls, cls,
@ -290,7 +290,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, min_length=1) validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
cls, cls,
@ -390,8 +390,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
validate_string(prompt, min_length=1) validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_dimensions(end_frame, max_width=7999, max_height=7999) validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
@ -475,7 +475,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
reference_images = None reference_images = None
if reference_image is not None: if reference_image is not None:
validate_image_dimensions(reference_image, max_width=7999, max_height=7999) validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
cls, cls,
reference_image, reference_image,

View File

@ -14,9 +14,9 @@ from comfy_api_nodes.util import (
poll_op, poll_op,
sync_op, sync_op,
upload_images_to_comfyapi, upload_images_to_comfyapi,
validate_aspect_ratio_closeness, validate_image_aspect_ratio,
validate_image_aspect_ratio_range,
validate_image_dimensions, validate_image_dimensions,
validate_images_aspect_ratio_closeness,
) )
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
@ -114,7 +114,7 @@ async def execute_task(
cls, cls,
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: r.state.value, status_extractor=lambda r: r.state,
estimated_duration=estimated_duration, estimated_duration=estimated_duration,
) )
@ -307,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if get_number_of_images(image) > 1: if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.") raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) validate_image_aspect_ratio(image, (1, 4), (4, 1))
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model_name=model,
prompt=prompt, prompt=prompt,
@ -423,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
if a > 7: if a > 7:
raise ValueError("Too many images, maximum allowed is 7.") raise ValueError("Too many images, maximum allowed is 7.")
for image in images: for image in images:
validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128) validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model_name=model,
@ -533,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
resolution: str, resolution: str,
movement_amplitude: str, movement_amplitude: str,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model_name=model,
prompt=prompt, prompt=prompt,

View File

@ -14,6 +14,7 @@ from .conversions import (
downscale_image_tensor, downscale_image_tensor,
image_tensor_pair_to_batch, image_tensor_pair_to_batch,
pil_to_bytesio, pil_to_bytesio,
resize_mask_to_image,
tensor_to_base64_string, tensor_to_base64_string,
tensor_to_bytesio, tensor_to_bytesio,
tensor_to_pil, tensor_to_pil,
@ -34,12 +35,12 @@ from .upload_helpers import (
) )
from .validation_utils import ( from .validation_utils import (
get_number_of_images, get_number_of_images,
validate_aspect_ratio_closeness, validate_aspect_ratio_string,
validate_audio_duration, validate_audio_duration,
validate_container_format_is_mp4, validate_container_format_is_mp4,
validate_image_aspect_ratio, validate_image_aspect_ratio,
validate_image_aspect_ratio_range,
validate_image_dimensions, validate_image_dimensions,
validate_images_aspect_ratio_closeness,
validate_string, validate_string,
validate_video_dimensions, validate_video_dimensions,
validate_video_duration, validate_video_duration,
@ -70,6 +71,7 @@ __all__ = [
"downscale_image_tensor", "downscale_image_tensor",
"image_tensor_pair_to_batch", "image_tensor_pair_to_batch",
"pil_to_bytesio", "pil_to_bytesio",
"resize_mask_to_image",
"tensor_to_base64_string", "tensor_to_base64_string",
"tensor_to_bytesio", "tensor_to_bytesio",
"tensor_to_pil", "tensor_to_pil",
@ -77,12 +79,12 @@ __all__ = [
"video_to_base64_string", "video_to_base64_string",
# Validation utilities # Validation utilities
"get_number_of_images", "get_number_of_images",
"validate_aspect_ratio_closeness", "validate_aspect_ratio_string",
"validate_audio_duration", "validate_audio_duration",
"validate_container_format_is_mp4", "validate_container_format_is_mp4",
"validate_image_aspect_ratio", "validate_image_aspect_ratio",
"validate_image_aspect_ratio_range",
"validate_image_dimensions", "validate_image_dimensions",
"validate_images_aspect_ratio_closeness",
"validate_string", "validate_string",
"validate_video_dimensions", "validate_video_dimensions",
"validate_video_duration", "validate_video_duration",

View File

@ -78,7 +78,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 429, 500, 502, 503, 504} _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"] FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]

View File

@ -430,3 +430,24 @@ def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
wav = torch.cat(frames, dim=1) # [C, T] wav = torch.cat(frames, dim=1) # [C, T]
wav = _f32_pcm(wav) wav = _f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
_, height, width, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
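
The movedim shuffling above exists because ComfyUI masks are [B, H, W] while images are [B, H, W, C] and the upscaler wants [B, C, H, W]. A standalone equivalent using torch.nn.functional.interpolate rather than common_upscale (a sketch of the same reshaping, not the helper itself):

import torch
import torch.nn.functional as F

def resize_mask_like(mask: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
    _, height, width, _ = image.shape      # image is [B, H, W, C]
    m = mask.unsqueeze(1)                  # [B, H, W] -> [B, 1, H, W]
    m = F.interpolate(m, size=(height, width), mode="nearest-exact")
    return m.squeeze(1)                    # back to [B, H, W]

mask = torch.rand(1, 64, 64)
image = torch.rand(1, 512, 512, 3)
print(resize_mask_like(mask, image).shape)  # torch.Size([1, 512, 512])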

View File

@ -232,11 +232,12 @@ async def download_url_to_video_output(
video_url: str, video_url: str,
*, *,
timeout: float = None, timeout: float = None,
max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None, cls: type[COMFY_IO.ComfyNode] = None,
) -> VideoFromFile: ) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.""" """Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO() result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls) await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
return VideoFromFile(result) return VideoFromFile(result)

View File

@ -37,63 +37,62 @@ def validate_image_dimensions(
def validate_image_aspect_ratio( def validate_image_aspect_ratio(
image: torch.Tensor, image: torch.Tensor,
min_aspect_ratio: Optional[float] = None, min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_aspect_ratio: Optional[float] = None, max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
):
width, height = get_image_dimensions(image)
aspect_ratio = width / height
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
def validate_image_aspect_ratio_range(
image: torch.Tensor,
min_ratio: tuple[float, float], # e.g. (1, 4)
max_ratio: tuple[float, float], # e.g. (4, 1)
*, *,
strict: bool = True, # True -> (min, max); False -> [min, max] strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float: ) -> float:
a1, b1 = min_ratio """Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
a2, b2 = max_ratio
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
lo, hi = (a1 / b1), (a2 / b2)
if lo > hi:
lo, hi = hi, lo
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
w, h = get_image_dimensions(image) w, h = get_image_dimensions(image)
if w <= 0 or h <= 0: if w <= 0 or h <= 0:
raise ValueError(f"Invalid image dimensions: {w}x{h}") raise ValueError(f"Invalid image dimensions: {w}x{h}")
ar = w / h ar = w / h
ok = (lo < ar < hi) if strict else (lo <= ar <= hi) _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
if not ok:
op = "<" if strict else ""
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
return ar return ar
def validate_aspect_ratio_closeness( def validate_images_aspect_ratio_closeness(
start_img, first_image: torch.Tensor,
end_img, second_image: torch.Tensor,
min_rel: float, min_rel: float, # e.g. 0.8
max_rel: float, max_rel: float, # e.g. 1.25
*, *,
strict: bool = False, # True => exclusive, False => inclusive strict: bool = False, # True -> (min, max); False -> [min, max]
) -> None: ) -> float:
w1, h1 = get_image_dimensions(start_img) """
w2, h2 = get_image_dimensions(end_img) Validates that the two images' aspect ratios are 'close'.
The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
Returns the computed closeness factor C.
"""
w1, h1 = get_image_dimensions(first_image)
w2, h2 = get_image_dimensions(second_image)
if min(w1, h1, w2, h2) <= 0: if min(w1, h1, w2, h2) <= 0:
raise ValueError("Invalid image dimensions") raise ValueError("Invalid image dimensions")
ar1 = w1 / h1 ar1 = w1 / h1
ar2 = w2 / h2 ar2 = w2 / h2
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
closeness = max(ar1, ar2) / min(ar1, ar2) closeness = max(ar1, ar2) / min(ar1, ar2)
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25 limit = max(max_rel, 1.0 / min_rel)
if (closeness >= limit) if strict else (closeness > limit): if (closeness >= limit) if strict else (closeness > limit):
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}{max_rel}.") raise ValueError(
f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
f"allowed range {min_rel}{max_rel} (limit {limit:.2g})."
)
return closeness
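
With the defaults used by the Vidu start/end node (min_rel=0.8, max_rel=1.25) the effective limit is 1.25, so a square frame paired with a 5:4 frame sits exactly on the boundary and passes. A small usage sketch, assuming the standard [B, H, W, C] image layout:

import torch
from comfy_api_nodes.util import validate_images_aspect_ratio_closeness

square = torch.rand(1, 512, 512, 3)  # aspect ratio 1.0
wide = torch.rand(1, 512, 640, 3)    # aspect ratio 1.25
# closeness = 1.25, limit = max(1.25, 1 / 0.8) = 1.25 -> passes (inclusive by default)
print(validate_images_aspect_ratio_closeness(square, wide, min_rel=0.8, max_rel=1.25))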
def validate_aspect_ratio_string(
aspect_ratio: str,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
"""Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
ar = _parse_aspect_ratio_string(aspect_ratio)
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
return ar
def validate_video_dimensions( def validate_video_dimensions(
@ -183,3 +182,49 @@ def validate_container_format_is_mp4(video: VideoInput) -> None:
container_format = video.get_container_format() container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}") raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
def _ratio_from_tuple(r: tuple[float, float]) -> float:
a, b = r
if a <= 0 or b <= 0:
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
return a / b
def _assert_ratio_bounds(
ar: float,
*,
min_ratio: Optional[tuple[float, float]] = None,
max_ratio: Optional[tuple[float, float]] = None,
strict: bool = True,
) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
if lo is not None and hi is not None and lo > hi:
lo, hi = hi, lo # normalize order if caller swapped them
if lo is not None:
if (ar <= lo) if strict else (ar < lo):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
if hi is not None:
if (ar >= hi) if strict else (ar > hi):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
def _parse_aspect_ratio_string(ar_str: str) -> float:
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
parts = ar_str.split(":")
if len(parts) != 2:
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
try:
a = int(parts[0].strip())
b = int(parts[1].strip())
except ValueError as exc:
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
if a <= 0 or b <= 0:
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
return a / b
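
Together these helpers let callers validate user-supplied 'X:Y' strings and numeric bounds in one place. A short usage sketch of validate_aspect_ratio_string as defined above:

from comfy_api_nodes.util import validate_aspect_ratio_string

# Within [1:4, 4:1] (inclusive by default): returns the numeric ratio.
print(validate_aspect_ratio_string("16:9", min_ratio=(1, 4), max_ratio=(4, 1)))      # ~1.78
print(validate_aspect_ratio_string(" 9 : 16 ", min_ratio=(1, 4), max_ratio=(4, 1)))  # whitespace tolerated

try:
    validate_aspect_ratio_string("21:9", max_ratio=(2, 1))  # 2.33 > 2.0
except ValueError as err:
    print(err)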

View File

@ -1,4 +1,9 @@
import bisect
import gc
import itertools import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -188,6 +193,9 @@ class BasicCache:
self._clean_cache() self._clean_cache()
self._clean_subcaches() self._clean_subcaches()
def poll(self, **kwargs):
pass
def _set_immediate(self, node_id, value): def _set_immediate(self, node_id, value):
assert self.initialized assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
@ -276,6 +284,9 @@ class NullCache:
def clean_unused(self): def clean_unused(self):
pass pass
def poll(self, **kwargs):
pass
def get(self, node_id): def get(self, node_id):
return None return None
@ -336,3 +347,75 @@ class LRUCache(BasicCache):
self._mark_used(child_id) self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self return self
#Iterating the cache for usage analysis might be expensive, so when a cleanup does
#trigger, take a sizeable chunk out to give breathing space on high-node / low-RAM-per-node flows.
RAM_CACHE_HYSTERESIS = 1.1
#Roughly in GB, though not exactly. It needs to be non-zero for the heuristic below,
#and as long as multi-GB models dwarf it the OOM scoring stays a reasonable approximation.
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)
def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
clean_list = []
for key, (outputs, _), in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
def scan_list_for_ram_usage(outputs):
nonlocal ram_usage
for output in outputs:
if isinstance(output, list):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
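
Eviction order therefore combines size and staleness: each entry's OOM score is its approximate RAM usage times 1.3 raised to the number of generations since last use, and bisect.insort keeps clean_list ascending so pop() always evicts the worst offender first. A toy run of the same scoring, with made-up sizes and generations:

import bisect

MULTIPLIER = 1.3  # mirrors RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER
generation = 10
entries = {  # key -> (approx RAM in GB, generation last used, timestamp)
    "big_recent": (8.0, 10, 300.0),
    "small_old": (0.5, 4, 100.0),
    "big_old": (6.0, 6, 200.0),
}

clean_list = []
for key, (ram_gb, used_gen, ts) in entries.items():
    oom_score = MULTIPLIER ** (generation - used_gen) * ram_gb
    bisect.insort(clean_list, (oom_score, ts, key))

# pop() takes the highest score first: big_old (6.0 * 1.3^4 ~= 17.1) goes before
# big_recent (8.0 * 1.3^0 = 8.0), which goes before small_old (0.5 * 1.3^6 ~= 2.4).
while clean_list:
    print(clean_list.pop())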

View File

@ -209,10 +209,15 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id) self.execution_cache_listeners[from_node_id].add(to_node_id)
def get_output_cache(self, from_node_id, to_node_id): def get_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache: if not to_node_id in self.execution_cache:
return None return None
return self.execution_cache[to_node_id].get(from_node_id) value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
return value
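
The write-back matters for the recency-based caches above: touching a dependency from the per-execution cache now refreshes it in the main output cache, so LRU/RAM-pressure eviction sees it as recently used rather than stale. The touch pattern in isolation, sketched with plain dicts:

def get_with_touch(exec_cache: dict, main_cache: dict, from_id: str, to_id: str):
    per_node = exec_cache.get(to_id)
    if per_node is None:
        return None
    value = per_node.get(from_id)
    if value is None:
        return None
    main_cache[from_id] = value  # touch: refresh recency in the main cache
    return value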
def cache_update(self, node_id, value): def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners: if node_id in self.execution_cache_listeners:

View File

@ -0,0 +1,47 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class ScaleROPE(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ScaleROPE",
category="advanced/model_patches",
description="Scale and shift the ROPE of the model.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
m = model.clone()
m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
return io.NodeOutput(m)
class RopeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ScaleROPE
]
async def comfy_entrypoint() -> RopeExtension:
return RopeExtension()

View File

@ -21,6 +21,7 @@ from comfy_execution.caching import (
NullCache, NullCache,
HierarchicalCache, HierarchicalCache,
LRUCache, LRUCache,
RAMPressureCache,
) )
from comfy_execution.graph import ( from comfy_execution.graph import (
DynamicPrompt, DynamicPrompt,
@ -88,49 +89,56 @@ class IsChangedCache:
return self.is_changed[node_id] return self.is_changed[node_id]
class CacheEntry(NamedTuple):
ui: dict
outputs: list
class CacheType(Enum): class CacheType(Enum):
CLASSIC = 0 CLASSIC = 0
LRU = 1 LRU = 1
NONE = 2 NONE = 2
RAM_PRESSURE = 3
class CacheSet: class CacheSet:
def __init__(self, cache_type=None, cache_size=None): def __init__(self, cache_type=None, cache_args={}):
if cache_type == CacheType.NONE: if cache_type == CacheType.NONE:
self.init_null_cache() self.init_null_cache()
logging.info("Disabling intermediate node cache.") logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU: elif cache_type == CacheType.LRU:
if cache_size is None: cache_size = cache_args.get("lru", 0)
cache_size = 0
self.init_lru_cache(cache_size) self.init_lru_cache(cache_size)
logging.info("Using LRU cache") logging.info("Using LRU cache")
else: else:
self.init_classic_cache() self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects] self.all = [self.outputs, self.objects]
# Performs like the old cache -- dump data ASAP # Performs like the old cache -- dump data ASAP
def init_classic_cache(self): def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature) self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size): def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self): def init_null_cache(self):
self.outputs = NullCache() self.outputs = NullCache()
#The UI cache is expected to be iterable at the end of each workflow
#so it must cache at least a full workflow. Use Hierarchical
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache() self.objects = NullCache()
def recursive_debug_dump(self): def recursive_debug_dump(self):
result = { result = {
"outputs": self.outputs.recursive_debug_dump(), "outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
} }
return result return result
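
The CacheEntry tuple is the pivot of this refactor: UI payloads no longer live in a separate caches.ui structure, they travel in the same cache slot as the node outputs. A standalone illustration of the shape of one entry (the payload values are invented for the example):

from typing import NamedTuple

class CacheEntry(NamedTuple):
    ui: dict       # the "meta"/"output" dict sent to the frontend, or None
    outputs: list  # one list per output socket, consumed by downstream nodes

entry = CacheEntry(
    ui={"meta": {"node_id": "5"}, "output": {"images": []}},
    outputs=[["IMAGE_TENSOR"]],
)
print(entry.outputs[0], entry.ui["output"])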
@@ -157,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
             if execution_list is None:
                 mark_missing()
                 continue # This might be a lazily-evaluated input
-            cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
-            if cached_output is None:
+            cached = execution_list.get_cache(input_unique_id, unique_id)
+            if cached is None or cached.outputs is None:
                 mark_missing()
                 continue
-            if output_index >= len(cached_output):
+            if output_index >= len(cached.outputs):
                 mark_missing()
                 continue
-            obj = cached_output[output_index]
+            obj = cached.outputs[output_index]
             input_data_all[x] = obj
         elif input_category is not None:
             input_data_all[x] = [input_data]
@ -393,7 +401,7 @@ def format_value(x):
else: else:
return str(x) return str(x)
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes): async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id) real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -401,12 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    if caches.outputs.get(unique_id) is not None:
+    cached = caches.outputs.get(unique_id)
+    if cached is not None:
         if server.client_id is not None:
-            cached_output = caches.ui.get(unique_id) or {}
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
+            cached_ui = cached.ui or {}
+            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
+        if cached.ui is not None:
+            ui_outputs[unique_id] = cached.ui
         get_progress_state().finish_progress(unique_id)
-        execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
+        execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)
 
     input_data_all = None
@@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 for r in result:
                     if is_link(r):
                         source_node, source_output = r[0], r[1]
-                        node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
-                        for o in node_output:
+                        node_cached = execution_list.get_cache(source_node, unique_id)
+                        for o in node_cached.outputs[source_output]:
                             resolved_output.append(o)
                     else:
@@ -445,6 +456,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 resolved_outputs.append(tuple(resolved_output))
             output_data = merge_result_data(resolved_outputs, class_def)
             output_ui = []
+            del pending_subgraph_results[unique_id]
             has_subgraph = False
         else:
             get_progress_state().start_progress(unique_id)
@@ -506,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 asyncio.create_task(await_completion())
                 return (ExecutionResult.PENDING, None, None)
         if len(output_ui) > 0:
-            caches.ui.set(unique_id, {
+            ui_outputs[unique_id] = {
                 "meta": {
                     "node_id": unique_id,
                     "display_node": display_node_id,
@@ -514,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                     "real_node_id": real_node_id,
                 },
                 "output": output_ui
-            })
+            }
             if server.client_id is not None:
                 server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
         if has_subgraph:
@@ -527,10 +539,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 if new_graph is None:
                     cached_outputs.append((False, node_outputs))
                 else:
-                    # Check for conflicts
-                    for node_id in new_graph.keys():
-                        if dynprompt.has_node(node_id):
-                            raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
                     for node_id, node_info in new_graph.items():
                         new_node_ids.append(node_id)
                         display_id = node_info.get("override_display_id", unique_id)
@@ -557,8 +565,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             pending_subgraph_results[unique_id] = cached_outputs
             return (ExecutionResult.PENDING, None, None)
 
-        caches.outputs.set(unique_id, output_data)
-        execution_list.cache_update(unique_id, output_data)
+        cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
+        execution_list.cache_update(unique_id, cache_entry)
+        caches.outputs.set(unique_id, cache_entry)
 
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
@@ -603,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     return (ExecutionResult.SUCCESS, None, None)
 
 
 class PromptExecutor:
-    def __init__(self, server, cache_type=False, cache_size=None):
-        self.cache_size = cache_size
+    def __init__(self, server, cache_type=False, cache_args=None):
+        self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()
 
     def reset(self):
-        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
         self.status_messages = []
         self.success = True
@@ -685,6 +694,7 @@ class PromptExecutor:
                              broadcast=False)
         pending_subgraph_results = {}
         pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+        ui_node_outputs = {}
         executed = set()
         execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()
@@ -698,7 +708,7 @@ class PromptExecutor:
                     break
 
                 assert node_id is not None, "Node ID should not be None at this point"
-                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
+                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
                 self.success = result != ExecutionResult.FAILURE
                 if result == ExecutionResult.FAILURE:
                     self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -707,18 +717,16 @@ class PromptExecutor:
                     execution_list.unstage_node_execution()
                 else: # result == ExecutionResult.SUCCESS:
                     execution_list.complete_node_execution()
+                    self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
             else:
                 # Only execute when the while-loop ends without break
                 self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
 
             ui_outputs = {}
             meta_outputs = {}
-            all_node_ids = self.caches.ui.all_node_ids()
-            for node_id in all_node_ids:
-                ui_info = self.caches.ui.get(node_id)
-                if ui_info is not None:
-                    ui_outputs[node_id] = ui_info["output"]
-                    meta_outputs[node_id] = ui_info["meta"]
+            for node_id, ui_info in ui_node_outputs.items():
+                ui_outputs[node_id] = ui_info["output"]
+                meta_outputs[node_id] = ui_info["meta"]
             self.history_result = {
                 "outputs": ui_outputs,
                 "meta": meta_outputs,

View File

@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
     cache_type = execution.CacheType.CLASSIC
     if args.cache_lru > 0:
         cache_type = execution.CacheType.LRU
+    elif args.cache_ram > 0:
+        cache_type = execution.CacheType.RAM_PRESSURE
     elif args.cache_none:
         cache_type = execution.CacheType.NONE
 
-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0

View File

@@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes():
         "nodes_model_patch.py",
         "nodes_easycache.py",
         "nodes_audio_encoder.py",
+        "nodes_rope.py",
     ]
 
     import_failed = []

View File

@@ -0,0 +1,232 @@
import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

def has_gpu():
    return torch.cuda.is_available()

from comfy.cli_args import args

if not has_gpu():
    args.cpu = True

from comfy import ops
from comfy.quant_ops import QuantizedTensor


class SimpleModel(torch.nn.Module):
    def __init__(self, operations=ops.disable_weight_init):
        super().__init__()
        self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
        self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
        self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)

    def forward(self, x):
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
        x = self.layer2(x)
        x = torch.nn.functional.relu(x)
        x = self.layer3(x)
        return x


class TestMixedPrecisionOps(unittest.TestCase):
    def test_all_layers_standard(self):
        """Test that model with no quantization works normally"""
        # Configure no quantization
        ops.MixedPrecisionOps._layer_quant_config = {}

        # Create model
        model = SimpleModel(operations=ops.MixedPrecisionOps)

        # Initialize weights manually
        model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
        model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
        model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16))
        model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
        model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
        model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))

        # Initialize weight_function and bias_function
        for layer in [model.layer1, model.layer2, model.layer3]:
            layer.weight_function = []
            layer.bias_function = []

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))
        self.assertEqual(output.dtype, torch.bfloat16)

    def test_mixed_precision_load(self):
        """Test loading a mixed precision model from state dict"""
        # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            },
            "layer3": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict with mixed precision
        fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            # Layer 1: FP8 E4M3FN
            "layer1.weight": fp8_weight1,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
            # Layer 2: Standard BF16
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            # Layer 3: FP8 E4M3FN
            "layer3.weight": fp8_weight3,
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
            "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
        }

        # Create model and load state dict (strict=False because custom loading pops keys)
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Verify weights are wrapped in QuantizedTensor
        self.assertIsInstance(model.layer1.weight, QuantizedTensor)
        self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")

        # Layer 2 should NOT be quantized
        self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)

        # Layer 3 should be quantized
        self.assertIsInstance(model.layer3.weight, QuantizedTensor)
        self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")

        # Verify scales were loaded
        self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
        self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)
        self.assertEqual(output.shape, (5, 40))

    def test_state_dict_quantized_preserved(self):
        """Test that quantized weights are preserved in state_dict()"""
        # Configure mixed precision
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict1 = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict1, strict=False)

        # Save state dict
        state_dict2 = model.state_dict()

        # Verify layer1.weight is a QuantizedTensor with scale preserved
        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
        self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
        self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")

        # Verify non-quantized layers are standard tensors
        self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
        self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)

    def test_weight_function_compatibility(self):
        """Test that weight_function (LoRA) works with quantized layers"""
        # Configure FP8 quantization
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Add a weight function (simulating LoRA)
        # This should trigger dequantization during forward pass
        def apply_lora(weight):
            lora_delta = torch.randn_like(weight) * 0.01
            return weight + lora_delta

        model.layer1.weight_function.append(apply_lora)

        # Forward pass should work with LoRA (triggers weight_function path)
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)
        self.assertEqual(output.shape, (5, 40))

    def test_error_handling_unknown_format(self):
        """Test that unknown formats raise error"""
        # Configure with unknown format
        layer_quant_config = {
            "layer1": {
                "format": "unknown_format_xyz",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict
        state_dict = {
            "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        with self.assertRaises(KeyError):
            model.load_state_dict(state_dict, strict=False)


if __name__ == "__main__":
    unittest.main()
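
A note on the weight_scale entries above: combined with the dequantize test in the next file, the convention these tests encode is that the stored FP8 payload times the per-layer scale recovers the original-precision weight. A two-line standalone check (assuming a PyTorch build with float8_e4m3fn support):

import torch

q = torch.ones(2, 2, dtype=torch.float32).to(torch.float8_e4m3fn)  # FP8 payload
scale = torch.tensor(3.0)                                          # weight_scale
print(q.to(torch.float32) * scale)  # all 3.0 -- the dequantized weight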

View File

@@ -0,0 +1,190 @@
import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

def has_gpu():
    return torch.cuda.is_available()

from comfy.cli_args import args

if not has_gpu():
    args.cpu = True

from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout


class TestQuantizedTensor(unittest.TestCase):
    """Test the QuantizedTensor subclass with FP8 layout"""

    def test_creation(self):
        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")

    def test_dequantize(self):
        """Test explicit dequantization"""
        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}

        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
        dequantized = qt.dequantize()

        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qt = QuantizedTensor.from_float(
            float_tensor,
            "TensorCoreFP8Layout",
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt.shape, (64, 32))

        # Verify dequantization gives approximately original values
        dequantized = qt.dequantize()
        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
        self.assertLess(mean_rel_error, 0.1)


class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations"""

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()
        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()
        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")

        # Verify it's a deep copy
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    @unittest.skipUnless(has_gpu(), "GPU not available")
    def test_to_device(self):
        """Test device transfer"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')
        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')


class TestTensorCoreFP8Layout(unittest.TestCase):
    """Test the TensorCoreFP8Layout implementation"""

    def test_quantize(self):
        """Test quantization method"""
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
        self.assertEqual(qdata.shape, float_tensor.shape)
        self.assertIn('scale', layout_params)
        self.assertIn('orig_dtype', layout_params)
        self.assertEqual(layout_params['orig_dtype'], torch.float32)

    def test_dequantize(self):
        """Test dequantization method"""
        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
        scale = torch.tensor(1.0)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )
        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)

        # Should approximately match original
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_q = QuantizedTensor.from_float(
            a_fp32,
            "TensorCoreFP8Layout",
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, QuantizedTensor)
        expected = torch.abs(a_fp32)

        # FP8 introduces quantization error, so use loose tolerance
        mean_error = (result - expected).abs().mean()
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")


if __name__ == "__main__":
    unittest.main()
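
The fallback test relies on QuantizedTensor being a torch.Tensor subclass that, for any op without a registered handler, dequantizes its arguments and re-dispatches on plain tensors. comfy/quant_ops.py is not part of this diff, so the snippet below is only a sketch of that pattern with a hypothetical wrapper class, not the actual implementation:

import torch
from torch.utils._pytree import tree_map

class DequantOnFallback(torch.Tensor):
    # Hypothetical stand-in for QuantizedTensor's fallback path.
    @staticmethod
    def __new__(cls, qdata, scale):
        # Wrapper subclass that advertises the dequantized shape/dtype.
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=torch.float32)

    def __init__(self, qdata, scale):
        self._qdata, self._scale = qdata, scale

    def dequantize(self):
        return self._qdata.to(torch.float32) * self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # No handler registered for `func`: dequantize every wrapped
        # argument and run the op on plain tensors instead.
        def unwrap(t):
            return t.dequantize() if isinstance(t, cls) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

q = DequantOnFallback(torch.ones(3).to(torch.float8_e4m3fn), torch.tensor(-2.0))
print(torch.abs(q))  # tensor([2., 2., 2.]) -- computed on dequantized values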