Merge branch 'master' into dr-support-pip-cm
commit b88c66bfa1
@@ -1,2 +1,3 @@
 ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause
@@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause
@@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause
@@ -150,6 +150,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
+    PinnedMem = "pinned_memory"
 
 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
 
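The new PinnedMem member is consumed like the other PerformanceFeature flags: membership in args.fast gates the optimization, which the pin_memory() helper added to model management later in this commit checks. A minimal, self-contained sketch of how the flag is parsed and tested (the parser here is a stand-in for the real one in comfy/cli_args.py):

import argparse
import enum

class PerformanceFeature(enum.Enum):
    Fp8MatrixMultiplication = "fp8_matrix_mult"
    CublasOps = "cublas_ops"
    AutoTune = "autotune"
    PinnedMem = "pinned_memory"

parser = argparse.ArgumentParser()
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, default=[])

# e.g. launched as: main.py --fast pinned_memory autotune
args = parser.parse_args(["--fast", "pinned_memory", "autotune"])

if PerformanceFeature.PinnedMem in args.fast:
    print("pinned host memory enabled")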
@@ -310,11 +310,13 @@ class ControlLoraOps:
             self.bias = None
 
         def forward(self, input):
-            weight, bias = comfy.ops.cast_bias_weight(self, input)
+            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
             if self.up is not None:
-                return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
+                x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
             else:
-                return torch.nn.functional.linear(input, weight, bias)
+                x = torch.nn.functional.linear(input, weight, bias)
+            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
     class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(
@@ -350,12 +352,13 @@ class ControlLoraOps:
 
 
         def forward(self, input):
-            weight, bias = comfy.ops.cast_bias_weight(self, input)
+            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
             if self.up is not None:
-                return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
+                x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
             else:
-                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+                x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
 class ControlLora(ControlNet):
     def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -333,6 +333,14 @@ class BaseModel(torch.nn.Module):
         if self.model_config.scaled_fp8 is not None:
             unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
 
+        # Save mixed precision metadata
+        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
+            metadata = {
+                "format_version": "1.0",
+                "layers": self.model_config.layer_quant_config
+            }
+            unet_state_dict["_quantization_metadata"] = metadata
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
         if self.model_type == ModelType.V_PREDICTION:
@@ -6,6 +6,20 @@ import math
 import logging
 import torch
 
 
+def detect_layer_quantization(metadata):
+    quant_key = "_quantization_metadata"
+    if metadata is not None and quant_key in metadata:
+        quant_metadata = metadata.pop(quant_key)
+        quant_metadata = json.loads(quant_metadata)
+        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
+            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
+            return quant_metadata["layers"]
+        else:
+            raise ValueError("Invalid quantization metadata format")
+    return None
+
+
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     else:
         model_config.optimizations["fp8"] = True
 
+    # Detect per-layer quantization (mixed precision)
+    layer_quant_config = detect_layer_quantization(metadata)
+    if layer_quant_config:
+        model_config.layer_quant_config = layer_quant_config
+        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+
     return model_config
 
 def unet_prefix_from_state_dict(state_dict):
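detect_layer_quantization() expects the string metadata written by the model_base change above: a JSON object with a format_version and a per-layer table. A hedged round-trip sketch (the layer name and format are made up for illustration; the function body is copied from the hunk so the snippet runs standalone):

import json

metadata = {
    "_quantization_metadata": json.dumps({
        "format_version": "1.0",
        "layers": {
            "diffusion_model.input_blocks.1.1.proj_in": {"format": "float8_e4m3fn"},  # hypothetical layer
        },
    })
}

def detect_layer_quantization(metadata):
    quant_key = "_quantization_metadata"
    if metadata is not None and quant_key in metadata:
        quant_metadata = json.loads(metadata.pop(quant_key))
        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
            return quant_metadata["layers"]
        raise ValueError("Invalid quantization metadata format")
    return None

print(detect_layer_quantization(metadata))
# {'diffusion_model.input_blocks.1.1.proj_in': {'format': 'float8_e4m3fn'}}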
@@ -1013,6 +1013,16 @@ if args.async_offload:
     NUM_STREAMS = 2
     logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
 
+def current_stream(device):
+    if device is None:
+        return None
+    if is_device_cuda(device):
+        return torch.cuda.current_stream()
+    elif is_device_xpu(device):
+        return torch.xpu.current_stream()
+    else:
+        return None
+
 stream_counters = {}
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1021,21 +1031,17 @@ def get_offload_stream(device):
 
     if device in STREAMS:
         ss = STREAMS[device]
-        s = ss[stream_counter]
+        #Sync the oldest stream in the queue with the current
+        ss[stream_counter].wait_stream(current_stream(device))
         stream_counter = (stream_counter + 1) % len(ss)
-        if is_device_cuda(device):
-            ss[stream_counter].wait_stream(torch.cuda.current_stream())
-        elif is_device_xpu(device):
-            ss[stream_counter].wait_stream(torch.xpu.current_stream())
         stream_counters[device] = stream_counter
-        return s
+        return ss[stream_counter]
     elif is_device_cuda(device):
         ss = []
         for k in range(NUM_STREAMS):
             ss.append(torch.cuda.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
-        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     elif is_device_xpu(device):
@@ -1044,18 +1050,14 @@ def get_offload_stream(device):
             ss.append(torch.xpu.Stream(device=device, priority=0))
         STREAMS[device] = ss
         s = ss[stream_counter]
-        stream_counter = (stream_counter + 1) % len(ss)
         stream_counters[device] = stream_counter
         return s
     return None
 
 def sync_stream(device, stream):
-    if stream is None:
+    if stream is None or current_stream(device) is None:
         return
-    if is_device_cuda(device):
-        torch.cuda.current_stream().wait_stream(stream)
-    elif is_device_xpu(device):
-        torch.xpu.current_stream().wait_stream(stream)
+    current_stream(device).wait_stream(stream)
 
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
@@ -1080,6 +1082,36 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
 
+def pin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
+        return True
+
+    return False
+
+def unpin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
+        return True
+
+    return False
+
 def sage_attention_enabled():
     return args.use_sage_attention
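For context on the new helpers: pin_memory() page-locks an existing CPU tensor via cudaHostRegister so later host-to-device copies can run asynchronously on the offload stream, and current_stream() gives sync_stream()/get_offload_stream() one backend-neutral way to fetch the compute stream. A rough illustration of why pinning matters, using only public PyTorch APIs rather than the ComfyUI helpers (which register the tensor's existing storage in place):

import torch

def copy_to_gpu_async(cpu_tensor, device="cuda"):
    # non_blocking copies only overlap with compute when the source lives in
    # pinned (page-locked) host memory; otherwise CUDA stages them synchronously.
    stream = torch.cuda.Stream(device=device)
    with torch.cuda.stream(stream):
        gpu_tensor = cpu_tensor.to(device, non_blocking=True)
    torch.cuda.current_stream(device).wait_stream(stream)
    return gpu_tensor

if torch.cuda.is_available():
    w = torch.randn(1024, 1024).pin_memory()  # allocates pinned memory (analogous effect to cudaHostRegister)
    w_gpu = copy_to_gpu_async(w)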
@@ -238,6 +238,7 @@ class ModelPatcher:
         self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
+        self.pinned = set()
 
         self.attachments: dict[str] = {}
         self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -618,6 +619,21 @@ class ModelPatcher:
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
+    def pin_weight_to_device(self, key):
+        weight, set_func, convert_func = get_key_weight(self.model, key)
+        if comfy.model_management.pin_memory(weight):
+            self.pinned.add(key)
+
+    def unpin_weight(self, key):
+        if key in self.pinned:
+            weight, set_func, convert_func = get_key_weight(self.model, key)
+            comfy.model_management.unpin_memory(weight)
+            self.pinned.remove(key)
+
+    def unpin_all_weights(self):
+        for key in list(self.pinned):
+            self.unpin_weight(key)
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
@@ -642,6 +658,7 @@ class ModelPatcher:
         loading = self._load_list()
 
         load_completely = []
+        offloaded = []
         loading.sort(reverse=True)
         for x in loading:
             n = x[1]
@@ -683,6 +700,7 @@ class ModelPatcher:
                             patch_counter += 1
 
                     cast_weight = True
+                    offloaded.append((module_mem, n, m, params))
                 else:
                     if hasattr(m, "comfy_cast_weights"):
                         wipe_lowvram_weight(m)
@@ -713,7 +731,9 @@ class ModelPatcher:
                     continue
 
                 for param in params:
-                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
+                    key = "{}.{}".format(n, param)
+                    self.unpin_weight(key)
+                    self.patch_weight_to_device(key, device_to=device_to)
 
                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                 m.comfy_patched_weights = True
@@ -721,6 +741,12 @@ class ModelPatcher:
         for x in load_completely:
             x[2].to(device_to)
 
+        for x in offloaded:
+            n = x[1]
+            params = x[3]
+            for param in params:
+                self.pin_weight_to_device("{}.{}".format(n, param))
+
         if lowvram_counter > 0:
             logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
@@ -762,6 +788,7 @@ class ModelPatcher:
         self.eject_model()
         if unpatch_weights:
             self.unpatch_hooks()
+            self.unpin_all_weights()
             if self.model.model_lowvram:
                 for m in self.model.modules():
                     move_weight_functions(m, device_to)
@@ -857,6 +884,9 @@ class ModelPatcher:
                 memory_freed += module_mem
                 logging.debug("freed {}".format(n))
 
+                for param in params:
+                    self.pin_weight_to_device("{}.{}".format(n, param))
+
         self.model.model_lowvram = True
         self.model.lowvram_patch_counter += patch_counter
         self.model.model_loaded_weight_memory -= memory_freed
@@ -1259,5 +1289,6 @@ class ModelPatcher:
         self.clear_cached_hook_weights()
 
     def __del__(self):
+        self.unpin_all_weights()
         self.detach(unpatch_all=False)
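Taken together, the ModelPatcher changes give pinned weights a simple lifecycle: parameters of modules that stay offloaded are pinned after a (partial) load, a key is unpinned right before being patched onto the device, and everything is unpinned on unpatch or object deletion. A toy sketch of that bookkeeping with stand-in pin/unpin callables (not the real comfy.model_management functions):

class PinTracker:
    def __init__(self, pin_fn, unpin_fn):
        self.pinned = set()
        self.pin_fn = pin_fn      # plays the role of comfy.model_management.pin_memory
        self.unpin_fn = unpin_fn  # plays the role of comfy.model_management.unpin_memory

    def pin(self, key):
        if self.pin_fn(key):
            self.pinned.add(key)

    def unpin(self, key):
        if key in self.pinned:
            self.unpin_fn(key)
            self.pinned.remove(key)

    def unpin_all(self):
        for key in list(self.pinned):
            self.unpin(key)

tracker = PinTracker(lambda k: True, lambda k: None)
tracker.pin("blocks.0.attn.qkv.weight")    # after a partial load
tracker.unpin("blocks.0.attn.qkv.weight")  # before patching the weight to the device
tracker.unpin_all()                        # on unpatch_model() / __del__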
comfy/ops.py (257 changed lines)
@@ -70,8 +70,12 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
 
 
 @torch.compiler.disable()
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+    # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
+    # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
+    # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
             dtype = input.dtype
@@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         if device is None:
             device = input.device
 
-    offload_stream = comfy.model_management.get_offload_stream(device)
+    if offloadable:
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
+
     if offload_stream is not None:
         wf_context = offload_stream
     else:
@@ -105,8 +113,25 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
             weight = f(weight)
 
     comfy.model_management.sync_stream(device, offload_stream)
-    return weight, bias
+    if offloadable:
+        return weight, bias, offload_stream
+    else:
+        #Legacy function signature
+        return weight, bias
+
+
+def uncast_bias_weight(s, weight, bias, offload_stream):
+    if offload_stream is None:
+        return
+    if weight is not None:
+        device = weight.device
+    else:
+        if bias is None:
+            return
+        device = bias.device
+    offload_stream.wait_stream(comfy.model_management.current_stream(device))
 
 
 class CastWeightBiasOp:
     comfy_cast_weights = False
     weight_function = []
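The NOTE above is aimed at custom node authors; the calling pattern it asks for is exactly what the built-in ops switch to in the rest of this commit. A hedged sketch of a custom cast-weight layer using the new API (cast_bias_weight, uncast_bias_weight and CastWeightBiasOp are the real comfy.ops names; the layer itself is hypothetical):

import torch
import comfy.ops

class MyCustomLinear(torch.nn.Linear, comfy.ops.CastWeightBiasOp):
    def forward_comfy_cast_weights(self, input):
        # New-style call: also returns the offload stream used for the async copy.
        weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
        x = torch.nn.functional.linear(input, weight, bias)
        # Signal that the casted copies are no longer needed so the stream can be
        # recycled; this is a no-op when offload_stream is None.
        comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
        return x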
@@ -118,8 +143,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -133,8 +160,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -148,8 +177,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -172,8 +203,10 @@ class disable_weight_init:
             return super()._conv_forward(input, weight, bias, *args, **kwargs)
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -187,8 +220,10 @@ class disable_weight_init:
             return None
 
         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -203,11 +238,14 @@ class disable_weight_init:
 
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
                 bias = None
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+                offload_stream = None
+            x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -223,11 +261,15 @@ class disable_weight_init:
 
         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
-            return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+                bias = None
+                offload_stream = None
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -246,10 +288,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose2d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -268,10 +312,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose1d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -289,8 +335,11 @@ class disable_weight_init:
             output_dtype = out_dtype
             if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                 out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
+            x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
 
         def forward(self, *args, **kwargs):
             run_every_op()
@@ -344,6 +393,10 @@ class manual_cast(disable_weight_init):
 
 
 def fp8_linear(self, input):
+    """
+    Legacy FP8 linear function for backward compatibility.
+    Uses QuantizedTensor subclass for dispatch.
+    """
     dtype = self.weight.dtype
     if dtype not in [torch.float8_e4m3fn]:
         return None
@@ -355,9 +408,9 @@ def fp8_linear(self, input):
     input_shape = input.shape
     input_dtype = input.dtype
 
     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
-        w = w.t()
+        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
 
         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -368,23 +421,20 @@ def fp8_linear(self, input):
 
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
         else:
             scale_input = scale_input.to(input.device)
-            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
 
-        if bias is not None:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
-        else:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
+        # Wrap weight in QuantizedTensor - this enables unified dispatch
+        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
+        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
+        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
-        if isinstance(o, tuple):
-            o = o[0]
+        uncast_bias_weight(self, w, bias, offload_stream)
 
         if tensor_2d:
             return o.reshape(input_shape[0], -1)
 
         return o.reshape((-1, input_shape[1], self.weight.shape[0]))
 
     return None
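The rewritten fp8_linear no longer calls torch._scaled_mm directly: it wraps the weight (and, when scales are available, the input) in QuantizedTensor so that an ordinary torch.nn.functional.linear call is handled by the quantized-tensor machinery in comfy/quant_ops.py. A hedged sketch of that flow (shapes and dtypes are illustrative; without FP8-capable hardware the subclass simply dequantizes and falls back):

import torch
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

compute_dtype = torch.bfloat16
weight = torch.randn(128, 64, dtype=compute_dtype)
x = torch.randn(4, 64, dtype=compute_dtype)

# Per-tensor FP8 quantization of the weight; the scale is computed automatically.
qw = QuantizedTensor.from_float(weight, TensorCoreFP8Layout, dtype=torch.float8_e4m3fn)

# A plain linear call: the QuantizedTensor subclass intercepts it and, since the
# activation here is not quantized, ends up on the dequantize-and-fall-back path.
# When both operands are FP8 QuantizedTensors on supported hardware, the
# registered handler uses torch._scaled_mm instead.
out = torch.nn.functional.linear(x, qw)
print(out.shape)  # torch.Size([4, 128])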
@@ -405,8 +455,10 @@ class fp8_ops(manual_cast):
             except Exception as e:
                 logging.info("Exception during fp8 op: {}".format(e))
 
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -434,12 +486,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
                 if out is not None:
                     return out
 
-            weight, bias = cast_bias_weight(self, input)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
 
             if weight.numel() < input.numel(): #TODO: optimize
-                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
             else:
-                return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+                x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
         def convert_weight(self, weight, inplace=False, **kwargs):
             if inplace:
@@ -478,7 +532,130 @@ if CUBLAS_IS_AVAILABLE:
         def forward(self, *args, **kwargs):
             return super().forward(*args, **kwargs)
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
+# ==============================================================================
+# Mixed Precision Operations
+# ==============================================================================
+from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
+
+QUANT_FORMAT_MIXINS = {
+    "float8_e4m3fn": {
+        "dtype": torch.float8_e4m3fn,
+        "layout_type": TensorCoreFP8Layout,
+        "parameters": {
+            "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+            "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+        }
+    }
+}
+
+class MixedPrecisionOps(disable_weight_init):
+    _layer_quant_config = {}
+    _compute_dtype = torch.bfloat16
+
+    class Linear(torch.nn.Module, CastWeightBiasOp):
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None,
+        ) -> None:
+            super().__init__()
+
+            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+            # self.factory_kwargs = {"device": device, "dtype": dtype}
+
+            self.in_features = in_features
+            self.out_features = out_features
+            if bias:
+                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+            else:
+                self.register_parameter("bias", None)
+
+            self.tensor_class = None
+
+        def reset_parameters(self):
+            return None
+
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                  strict, missing_keys, unexpected_keys, error_msgs):
+
+            device = self.factory_kwargs["device"]
+            layer_name = prefix.rstrip('.')
+            weight_key = f"{prefix}weight"
+            weight = state_dict.pop(weight_key, None)
+            if weight is None:
+                raise ValueError(f"Missing weight for layer {layer_name}")
+
+            manually_loaded_keys = [weight_key]
+
+            if layer_name not in MixedPrecisionOps._layer_quant_config:
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+            else:
+                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                if quant_format is None:
+                    raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                mixin = QUANT_FORMAT_MIXINS[quant_format]
+                self.layout_type = mixin["layout_type"]
+
+                scale_key = f"{prefix}weight_scale"
+                layout_params = {
+                    'scale': state_dict.pop(scale_key, None),
+                    'orig_dtype': MixedPrecisionOps._compute_dtype
+                }
+                if layout_params['scale'] is not None:
+                    manually_loaded_keys.append(scale_key)
+
+                self.weight = torch.nn.Parameter(
+                    QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                    requires_grad=False
+                )
+
+                for param_name, param_value in mixin["parameters"].items():
+                    param_key = f"{prefix}{param_name}"
+                    _v = state_dict.pop(param_key, None)
+                    if _v is None:
+                        continue
+                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    manually_loaded_keys.append(param_key)
+
+            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            for key in manually_loaded_keys:
+                if key in missing_keys:
+                    missing_keys.remove(key)
+
+        def _forward(self, input, weight, bias):
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
+
+        def forward(self, input, *args, **kwargs):
+            run_every_op()
+
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(input, *args, **kwargs)
+            if (getattr(self, 'layout_type', None) is not None and
+                getattr(self, 'input_scale', None) is not None and
+                not isinstance(input, QuantizedTensor)):
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
+            return self._forward(input, self.weight, self.bias)
+
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
+        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
+        MixedPrecisionOps._compute_dtype = compute_dtype
+        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
+        return MixedPrecisionOps
+
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
         return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
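With this hunk, pick_operations() short-circuits to MixedPrecisionOps whenever the model_config carries a layer_quant_config, before any of the existing scaled-fp8/manual-cast decisions. A hedged sketch of that selection in isolation (the config object is a stand-in, not a real comfy.supported_models entry):

import types
import torch
import comfy.ops

model_config = types.SimpleNamespace(
    # hypothetical per-layer table, as produced by detect_layer_quantization()
    layer_quant_config={"diffusion_model.some_block.to_q": {"format": "float8_e4m3fn"}},
)

ops = comfy.ops.pick_operations(
    weight_dtype=torch.bfloat16,
    compute_dtype=torch.bfloat16,
    model_config=model_config,
)
assert ops is comfy.ops.MixedPrecisionOps
linear_cls = ops.Linear  # quantization-aware Linear used when instantiating the UNet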
comfy/quant_ops.py (new file, 437 lines)
@@ -0,0 +1,437 @@
+import torch
+import logging
+from typing import Tuple, Dict
+
+_LAYOUT_REGISTRY = {}
+_GENERIC_UTILS = {}
+
+
+def register_layout_op(torch_op, layout_type):
+    """
+    Decorator to register a layout-specific operation handler.
+    Args:
+        torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
+        layout_type: Layout class (e.g., TensorCoreFP8Layout)
+    Example:
+        @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+        def fp8_linear(func, args, kwargs):
+            # FP8-specific linear implementation
+            ...
+    """
+    def decorator(handler_func):
+        if torch_op not in _LAYOUT_REGISTRY:
+            _LAYOUT_REGISTRY[torch_op] = {}
+        _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
+        return handler_func
+    return decorator
+
+
+def register_generic_util(torch_op):
+    """
+    Decorator to register a generic utility that works for all layouts.
+    Args:
+        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
+
+    Example:
+        @register_generic_util(torch.ops.aten.detach.default)
+        def generic_detach(func, args, kwargs):
+            # Works for any layout
+            ...
+    """
+    def decorator(handler_func):
+        _GENERIC_UTILS[torch_op] = handler_func
+        return handler_func
+    return decorator
+
+
+def _get_layout_from_args(args):
+    for arg in args:
+        if isinstance(arg, QuantizedTensor):
+            return arg._layout_type
+        elif isinstance(arg, (list, tuple)):
+            for item in arg:
+                if isinstance(item, QuantizedTensor):
+                    return item._layout_type
+    return None
+
+
+def _move_layout_params_to_device(params, device):
+    new_params = {}
+    for k, v in params.items():
+        if isinstance(v, torch.Tensor):
+            new_params[k] = v.to(device=device)
+        else:
+            new_params[k] = v
+    return new_params
+
+
+def _copy_layout_params(params):
+    new_params = {}
+    for k, v in params.items():
+        if isinstance(v, torch.Tensor):
+            new_params[k] = v.clone()
+        else:
+            new_params[k] = v
+    return new_params
+
+
+class QuantizedLayout:
+    """
+    Base class for quantization layouts.
+
+    A layout encapsulates the format-specific logic for quantization/dequantization
+    and provides a uniform interface for extracting raw tensors needed for computation.
+
+    New quantization formats should subclass this and implement the required methods.
+    """
+    @classmethod
+    def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
+        raise NotImplementedError(f"{cls.__name__} must implement quantize()")
+
+    @staticmethod
+    def dequantize(qdata, **layout_params) -> torch.Tensor:
+        raise NotImplementedError("TensorLayout must implement dequantize()")
+
+    @classmethod
+    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
+        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
+
+class QuantizedTensor(torch.Tensor):
+    """
+    Universal quantized tensor that works with any layout.
+
+    This tensor subclass uses a pluggable layout system to support multiple
+    quantization formats (FP8, INT4, INT8, etc.) without code duplication.
+
+    The layout_type determines format-specific behavior, while common operations
+    (detach, clone, to) are handled generically.
+
+    Attributes:
+        _qdata: The quantized tensor data
+        _layout_type: Layout class (e.g., TensorCoreFP8Layout)
+        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
+    """
+
+    @staticmethod
+    def __new__(cls, qdata, layout_type, layout_params):
+        """
+        Create a quantized tensor.
+
+        Args:
+            qdata: The quantized data tensor
+            layout_type: Layout class (subclass of QuantizedLayout)
+            layout_params: Dict with layout-specific parameters
+        """
+        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
+
+    def __init__(self, qdata, layout_type, layout_params):
+        self._qdata = qdata.contiguous()
+        self._layout_type = layout_type
+        self._layout_params = layout_params
+
+    def __repr__(self):
+        layout_name = self._layout_type.__name__
+        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
+        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
+
+    @property
+    def layout_type(self):
+        return self._layout_type
+
+    def __tensor_flatten__(self):
+        """
+        Tensor flattening protocol for proper device movement.
+        """
+        inner_tensors = ["_qdata"]
+        ctx = {
+            "layout_type": self._layout_type,
+        }
+
+        tensor_params = {}
+        non_tensor_params = {}
+        for k, v in self._layout_params.items():
+            if isinstance(v, torch.Tensor):
+                tensor_params[k] = v
+            else:
+                non_tensor_params[k] = v
+
+        ctx["tensor_param_keys"] = list(tensor_params.keys())
+        ctx["non_tensor_params"] = non_tensor_params
+
+        for k, v in tensor_params.items():
+            attr_name = f"_layout_param_{k}"
+            object.__setattr__(self, attr_name, v)
+            inner_tensors.append(attr_name)
+
+        return inner_tensors, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
+        """
+        Tensor unflattening protocol for proper device movement.
+        Reconstructs the QuantizedTensor after device movement.
+        """
+        layout_type = ctx["layout_type"]
+        layout_params = dict(ctx["non_tensor_params"])
+
+        for key in ctx["tensor_param_keys"]:
+            attr_name = f"_layout_param_{key}"
+            layout_params[key] = inner_tensors[attr_name]
+
+        return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)
+
+    @classmethod
+    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
+        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
+        return cls(qdata, layout_type, layout_params)
+
+    def dequantize(self) -> torch.Tensor:
+        return self._layout_type.dequantize(self._qdata, **self._layout_params)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        # Step 1: Check generic utilities first (detach, clone, to, etc.)
+        if func in _GENERIC_UTILS:
+            return _GENERIC_UTILS[func](func, args, kwargs)
+
+        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
+        layout_type = _get_layout_from_args(args)
+        if layout_type and func in _LAYOUT_REGISTRY:
+            handler = _LAYOUT_REGISTRY[func].get(layout_type)
+            if handler:
+                return handler(func, args, kwargs)
+
+        # Step 3: Fallback to dequantization
+        if isinstance(args[0] if args else None, QuantizedTensor):
+            logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
+        return cls._dequant_and_fallback(func, args, kwargs)
+
+    @classmethod
+    def _dequant_and_fallback(cls, func, args, kwargs):
+        def dequant_arg(arg):
+            if isinstance(arg, QuantizedTensor):
+                return arg.dequantize()
+            elif isinstance(arg, (list, tuple)):
+                return type(arg)(dequant_arg(a) for a in arg)
+            return arg
+
+        new_args = dequant_arg(args)
+        new_kwargs = dequant_arg(kwargs)
+        return func(*new_args, **new_kwargs)
+
+# ==============================================================================
+# Generic Utilities (Layout-Agnostic Operations)
+# ==============================================================================
+
+def _create_transformed_qtensor(qt, transform_fn):
+    new_data = transform_fn(qt._qdata)
+    new_params = _copy_layout_params(qt._layout_params)
+    return QuantizedTensor(new_data, qt._layout_type, new_params)
+
+
+def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
+    if target_dtype is not None and target_dtype != qt.dtype:
+        logging.warning(
+            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
+            f"but not supported for quantized tensors. Ignoring dtype."
+        )
+
+    if target_layout is not None and target_layout != torch.strided:
+        logging.warning(
+            f"QuantizedTensor: layout change requested to {target_layout}, "
+            f"but not supported. Ignoring layout."
+        )
+
+    # Handle device transfer
+    current_device = qt._qdata.device
+    if target_device is not None:
+        # Normalize device for comparison
+        if isinstance(target_device, str):
+            target_device = torch.device(target_device)
+        if isinstance(current_device, str):
+            current_device = torch.device(current_device)
+
+        if target_device != current_device:
+            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
+            new_q_data = qt._qdata.to(device=target_device)
+            new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
+            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
+            return new_qt
+
+    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
+    return qt
+
+
+@register_generic_util(torch.ops.aten.detach.default)
+def generic_detach(func, args, kwargs):
+    """Detach operation - creates a detached copy of the quantized tensor."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _create_transformed_qtensor(qt, lambda x: x.detach())
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.clone.default)
+def generic_clone(func, args, kwargs):
+    """Clone operation - creates a deep copy of the quantized tensor."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _create_transformed_qtensor(qt, lambda x: x.clone())
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten._to_copy.default)
+def generic_to_copy(func, args, kwargs):
+    """Device/dtype transfer operation - handles .to(device) calls."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _handle_device_transfer(
+            qt,
+            target_device=kwargs.get('device', None),
+            target_dtype=kwargs.get('dtype', None),
+            op_name="_to_copy"
+        )
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.to.dtype_layout)
+def generic_to_dtype_layout(func, args, kwargs):
+    """Handle .to(device) calls using the dtype_layout variant."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _handle_device_transfer(
+            qt,
+            target_device=kwargs.get('device', None),
+            target_dtype=kwargs.get('dtype', None),
+            target_layout=kwargs.get('layout', None),
+            op_name="to"
+        )
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.copy_.default)
+def generic_copy_(func, args, kwargs):
+    qt_dest = args[0]
+    src = args[1]
+
+    if isinstance(qt_dest, QuantizedTensor):
+        if isinstance(src, QuantizedTensor):
+            # Copy from another quantized tensor
+            qt_dest._qdata.copy_(src._qdata)
+            qt_dest._layout_type = src._layout_type
+            qt_dest._layout_params = _copy_layout_params(src._layout_params)
+        else:
+            # Copy from regular tensor - just copy raw data
+            qt_dest._qdata.copy_(src)
+        return qt_dest
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
+def generic_has_compatible_shallow_copy_type(func, args, kwargs):
+    return True
+
# ==============================================================================
|
||||||
|
# FP8 Layout + Operation Handlers
|
||||||
|
# ==============================================================================
|
||||||
|
class TensorCoreFP8Layout(QuantizedLayout):
|
||||||
|
"""
|
||||||
|
Storage format:
|
||||||
|
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
|
||||||
|
- scale: Scalar tensor (float32) for dequantization
|
||||||
|
- orig_dtype: Original dtype before quantization (for casting back)
|
||||||
|
"""
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
|
||||||
|
orig_dtype = tensor.dtype
|
||||||
|
|
||||||
|
if scale is None:
|
||||||
|
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
|
||||||
|
|
||||||
|
if not isinstance(scale, torch.Tensor):
|
||||||
|
scale = torch.tensor(scale)
|
||||||
|
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
lp_amax = torch.finfo(dtype).max
|
||||||
|
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
|
||||||
|
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
|
||||||
|
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
|
layout_params = {
|
||||||
|
'scale': scale,
|
||||||
|
'orig_dtype': orig_dtype
|
||||||
|
}
|
||||||
|
return qdata, layout_params
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||||
|
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||||
|
return plain_tensor * scale
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_plain_tensors(cls, qtensor):
|
||||||
|
return qtensor._qdata, qtensor._layout_params['scale']
|
||||||
|
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
||||||
|
def fp8_linear(func, args, kwargs):
|
||||||
|
input_tensor = args[0]
|
||||||
|
weight = args[1]
|
||||||
|
bias = args[2] if len(args) > 2 else None
|
||||||
|
|
||||||
|
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||||
|
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||||
|
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||||
|
|
||||||
|
out_dtype = kwargs.get("out_dtype")
|
||||||
|
if out_dtype is None:
|
||||||
|
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||||
|
|
||||||
|
weight_t = plain_weight.t()
|
||||||
|
|
||||||
|
tensor_2d = False
|
||||||
|
if len(plain_input.shape) == 2:
|
||||||
|
tensor_2d = True
|
||||||
|
plain_input = plain_input.unsqueeze(1)
|
||||||
|
|
||||||
|
input_shape = plain_input.shape
|
||||||
|
if len(input_shape) != 3:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
plain_input.reshape(-1, input_shape[2]),
|
||||||
|
weight_t,
|
||||||
|
bias=bias,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
)
|
||||||
|
if not tensor_2d:
|
||||||
|
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||||
|
|
||||||
|
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||||
|
output_scale = scale_a * scale_b
|
||||||
|
output_params = {
|
||||||
|
'scale': output_scale,
|
||||||
|
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||||
|
}
|
||||||
|
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
|
||||||
|
else:
|
||||||
|
return output
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||||
|
|
||||||
|
# Case 2: DQ Fallback
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight = weight.dequantize()
|
||||||
|
if isinstance(input_tensor, QuantizedTensor):
|
||||||
|
input_tensor = input_tensor.dequantize()
|
||||||
|
|
||||||
|
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||||
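TensorCoreFP8Layout is a plain absmax scheme: the scale defaults to amax(|x|) / finfo(fp8).max, the tensor is divided by the scale, clamped, and cast to FP8, and dequantize multiplies the scale back in. A minimal round-trip sketch using only the two methods shown above (the input tensor is an arbitrary example; the reconstruction is lossy by design):

    import torch

    x = torch.randn(128, 256, dtype=torch.bfloat16)
    qdata, params = TensorCoreFP8Layout.quantize(x, dtype=torch.float8_e4m3fn)
    # qdata is float8_e4m3fn; params carries the float32 scale and the original dtype
    x_approx = TensorCoreFP8Layout.dequantize(qdata, params['scale'], params['orig_dtype'])
    print((x - x_approx).abs().max())  # small but nonzero: absmax FP8 quantization is lossy
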
comfy/sd.py (11 changed lines)

@@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)


-def load_diffusion_model_state_dict(sd, model_options={}):
+def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.

@@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
     weight_dtype = comfy.utils.weight_dtype(sd)

     load_device = model_management.get_torch_device()
-    model_config = model_detection.model_config_from_unet(sd, "")
+    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)

     if model_config is not None:
         new_sd = sd
@@ -1330,6 +1330,9 @@ def load_diffusion_model_state_dict(sd, model_options={}):
         else:
             unet_dtype = dtype

+        if model_config.layer_quant_config is not None:
+            manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+        else:
             manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
         model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
         model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
@@ -1346,8 +1349,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
 def load_diffusion_model(unet_path, model_options={}):
-    sd = comfy.utils.load_torch_file(unet_path)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options)
+    sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
@@ -50,6 +50,7 @@ class BASE:
     manual_cast_dtype = None
     custom_operations = None
     scaled_fp8 = None
+    layer_quant_config = None  # Per-layer quantization configuration for mixed precision
     optimizations = {"fp8": False}

     @classmethod
@@ -3,14 +3,6 @@ import aiohttp
 import mimetypes
 from typing import Optional, Union
 from comfy.utils import common_upscale
-from comfy_api_nodes.apis.client import (
-    ApiClient,
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    UploadRequest,
-    UploadResponse,
-)
 from server import PromptServer
 from comfy.cli_args import args

@@ -19,7 +11,6 @@ from PIL import Image
 import torch
 import math
 import base64
-from .util import tensor_to_bytesio, bytesio_to_image_tensor
 from io import BytesIO


@@ -148,11 +139,6 @@ async def download_url_to_bytesio(
     return BytesIO(await resp.read())


-def process_image_response(response_content: bytes | str) -> torch.Tensor:
-    """Uses content from a Response object and converts it to a torch.Tensor"""
-    return bytesio_to_image_tensor(BytesIO(response_content))
-
-
 def text_filepath_to_base64_string(filepath: str) -> str:
     """Converts a text file to a base64 string."""
     with open(filepath, "rb") as f:
@@ -169,73 +155,6 @@ def text_filepath_to_data_uri(filepath: str) -> str:
     return f"data:{mime_type};base64,{base64_string}"


-async def upload_file_to_comfyapi(
-    file_bytes_io: BytesIO,
-    filename: str,
-    upload_mime_type: Optional[str],
-    auth_kwargs: Optional[dict[str, str]] = None,
-) -> str:
-    """
-    Uploads a single file to ComfyUI API and returns its download URL.
-
-    Args:
-        file_bytes_io: BytesIO object containing the file data.
-        filename: The filename of the file.
-        upload_mime_type: MIME type of the file.
-        auth_kwargs: Optional authentication token(s).
-
-    Returns:
-        The download URL for the uploaded file.
-    """
-    if upload_mime_type is None:
-        request_object = UploadRequest(file_name=filename)
-    else:
-        request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
-    operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path="/customers/storage",
-            method=HttpMethod.POST,
-            request_model=UploadRequest,
-            response_model=UploadResponse,
-        ),
-        request=request_object,
-        auth_kwargs=auth_kwargs,
-    )
-
-    response: UploadResponse = await operation.execute()
-    await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
-    return response.download_url
-
-
-async def upload_images_to_comfyapi(
-    image: torch.Tensor,
-    max_images=8,
-    auth_kwargs: Optional[dict[str, str]] = None,
-    mime_type: Optional[str] = None,
-) -> list[str]:
-    """
-    Uploads images to ComfyUI API and returns download URLs.
-    To upload multiple images, stack them in the batch dimension first.
-
-    Args:
-        image: Input torch.Tensor image.
-        max_images: Maximum number of images to upload.
-        auth_kwargs: Optional authentication token(s).
-        mime_type: Optional MIME type for the image.
-    """
-    # if batch, try to upload each file if max_images is greater than 0
-    download_urls: list[str] = []
-    is_batch = len(image.shape) > 3
-    batch_len = image.shape[0] if is_batch else 1
-
-    for idx in range(min(batch_len, max_images)):
-        tensor = image[idx] if is_batch else image
-        img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
-        url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
-        download_urls.append(url)
-    return download_urls
-
-
 def resize_mask_to_image(
     mask: torch.Tensor,
     image: torch.Tensor,
120
comfy_api_nodes/apis/minimax_api.py
Normal file
120
comfy_api_nodes/apis/minimax_api.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxBaseResponse(BaseModel):
|
||||||
|
status_code: int = Field(
|
||||||
|
...,
|
||||||
|
description='Status code. 0 indicates success, other values indicate errors.',
|
||||||
|
)
|
||||||
|
status_msg: str = Field(
|
||||||
|
..., description='Specific error details or success message.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class File(BaseModel):
|
||||||
|
bytes: Optional[int] = Field(None, description='File size in bytes')
|
||||||
|
created_at: Optional[int] = Field(
|
||||||
|
None, description='Unix timestamp when the file was created, in seconds'
|
||||||
|
)
|
||||||
|
download_url: Optional[str] = Field(
|
||||||
|
None, description='The URL to download the video'
|
||||||
|
)
|
||||||
|
backup_download_url: Optional[str] = Field(
|
||||||
|
None, description='The backup URL to download the video'
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
|
||||||
|
filename: Optional[str] = Field(None, description='The name of the file')
|
||||||
|
purpose: Optional[str] = Field(None, description='The purpose of using the file')
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxFileRetrieveResponse(BaseModel):
|
||||||
|
base_resp: MinimaxBaseResponse
|
||||||
|
file: File
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxModel(str, Enum):
|
||||||
|
T2V_01_Director = 'T2V-01-Director'
|
||||||
|
I2V_01_Director = 'I2V-01-Director'
|
||||||
|
S2V_01 = 'S2V-01'
|
||||||
|
I2V_01 = 'I2V-01'
|
||||||
|
I2V_01_live = 'I2V-01-live'
|
||||||
|
T2V_01 = 'T2V-01'
|
||||||
|
Hailuo_02 = 'MiniMax-Hailuo-02'
|
||||||
|
|
||||||
|
|
||||||
|
class Status6(str, Enum):
|
||||||
|
Queueing = 'Queueing'
|
||||||
|
Preparing = 'Preparing'
|
||||||
|
Processing = 'Processing'
|
||||||
|
Success = 'Success'
|
||||||
|
Fail = 'Fail'
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxTaskResultResponse(BaseModel):
|
||||||
|
base_resp: MinimaxBaseResponse
|
||||||
|
file_id: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
|
||||||
|
)
|
||||||
|
status: Status6 = Field(
|
||||||
|
...,
|
||||||
|
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
|
||||||
|
)
|
||||||
|
task_id: str = Field(..., description='The task ID being queried.')
|
||||||
|
|
||||||
|
|
||||||
|
class SubjectReferenceItem(BaseModel):
|
||||||
|
image: Optional[str] = Field(
|
||||||
|
None, description='URL or base64 encoding of the subject reference image.'
|
||||||
|
)
|
||||||
|
mask: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='URL or base64 encoding of the mask for the subject reference image.',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxVideoGenerationRequest(BaseModel):
|
||||||
|
callback_url: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='Optional. URL to receive real-time status updates about the video generation task.',
|
||||||
|
)
|
||||||
|
first_frame_image: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
|
||||||
|
)
|
||||||
|
model: MiniMaxModel = Field(
|
||||||
|
...,
|
||||||
|
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
|
||||||
|
)
|
||||||
|
prompt: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
|
||||||
|
max_length=2000,
|
||||||
|
)
|
||||||
|
prompt_optimizer: Optional[bool] = Field(
|
||||||
|
True,
|
||||||
|
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
|
||||||
|
)
|
||||||
|
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
|
||||||
|
None,
|
||||||
|
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
|
||||||
|
)
|
||||||
|
duration: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description="The length of the output video in seconds."
|
||||||
|
)
|
||||||
|
resolution: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxVideoGenerationResponse(BaseModel):
|
||||||
|
base_resp: MinimaxBaseResponse
|
||||||
|
task_id: str = Field(
|
||||||
|
..., description='The task ID for the asynchronous video generation task.'
|
||||||
|
)
|
||||||
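These are ordinary pydantic models, so a request body can be built and serialized before it is handed to the API client. A small sketch (the prompt, duration, and resolution values are arbitrary examples):

    request = MinimaxVideoGenerationRequest(
        model=MiniMaxModel.Hailuo_02,
        prompt="A slow pan across a foggy harbor at dawn",
        duration=6,
        resolution="768P",
    )
    payload = request.model_dump(exclude_none=True)  # dict ready to serialize as the JSON body
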
@@ -20,9 +20,9 @@ from comfy_api_nodes.apis.client import (

 from comfy_api_nodes.apinode_utils import (
     download_url_to_bytesio,
-    bytesio_to_image_tensor,
     resize_mask_to_image,
 )
+from comfy_api_nodes.util import bytesio_to_image_tensor
 from server import PromptServer

 V1_V1_RES_MAP = {
@@ -1,69 +1,51 @@
-from __future__ import annotations
-from inspect import cleandoc
 from typing import Optional

+import torch
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, IO
-from comfy_api.input_impl.video_types import VideoFromFile
+from comfy_api.latest import IO, ComfyExtension
 from comfy_api_nodes.apis.luma_api import (
-    LumaImageModel,
-    LumaVideoModel,
-    LumaVideoOutputResolution,
-    LumaVideoModelOutputDuration,
     LumaAspectRatio,
-    LumaState,
-    LumaImageGenerationRequest,
-    LumaGenerationRequest,
-    LumaGeneration,
     LumaCharacterRef,
-    LumaModifyImageRef,
+    LumaConceptChain,
+    LumaGeneration,
+    LumaGenerationRequest,
+    LumaImageGenerationRequest,
     LumaImageIdentity,
+    LumaImageModel,
+    LumaImageReference,
+    LumaIO,
+    LumaKeyframes,
+    LumaModifyImageRef,
     LumaReference,
     LumaReferenceChain,
-    LumaImageReference,
-    LumaKeyframes,
-    LumaConceptChain,
-    LumaIO,
+    LumaVideoModel,
+    LumaVideoModelOutputDuration,
+    LumaVideoOutputResolution,
     get_luma_concepts,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
-from comfy_api_nodes.apinode_utils import (
+    download_url_to_image_tensor,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
     upload_images_to_comfyapi,
-    process_image_response,
+    validate_string,
 )
-from server import PromptServer
-from comfy_api_nodes.util import validate_string
-
-import aiohttp
-import torch
-from io import BytesIO

 LUMA_T2V_AVERAGE_DURATION = 105
 LUMA_I2V_AVERAGE_DURATION = 100

-def image_result_url_extractor(response: LumaGeneration):
-    return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
-
-def video_result_url_extractor(response: LumaGeneration):
-    return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
-
 class LumaReferenceNode(IO.ComfyNode):
-    """
-    Holds an image and weight for use with Luma Generate Image node.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaReferenceNode",
             display_name="Luma Reference",
             category="api node/image/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Holds an image and weight for use with Luma Generate Image node.",
             inputs=[
                 IO.Image.Input(
                     "image",
@@ -83,17 +65,10 @@ class LumaReferenceNode(IO.ComfyNode):
                 ),
             ],
             outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
         )

     @classmethod
-    def execute(
-        cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
-    ) -> IO.NodeOutput:
+    def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
         if luma_ref is not None:
             luma_ref = luma_ref.clone()
         else:
@@ -103,17 +78,13 @@ class LumaReferenceNode(IO.ComfyNode):


 class LumaConceptsNode(IO.ComfyNode):
-    """
-    Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaConceptsNode",
             display_name="Luma Concepts",
             category="api node/video/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
             inputs=[
                 IO.Combo.Input(
                     "concept1",
@@ -138,11 +109,6 @@ class LumaConceptsNode(IO.ComfyNode):
                 ),
             ],
             outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
         )

     @classmethod
@@ -161,17 +127,13 @@ class LumaConceptsNode(IO.ComfyNode):


 class LumaImageGenerationNode(IO.ComfyNode):
-    """
-    Generates images synchronously based on prompt and aspect ratio.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaImageNode",
             display_name="Luma Text to Image",
             category="api node/image/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates images synchronously based on prompt and aspect ratio.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -237,45 +199,30 @@ class LumaImageGenerationNode(IO.ComfyNode):
         aspect_ratio: str,
         seed,
         style_image_weight: float,
-        image_luma_ref: LumaReferenceChain = None,
-        style_image: torch.Tensor = None,
-        character_image: torch.Tensor = None,
+        image_luma_ref: Optional[LumaReferenceChain] = None,
+        style_image: Optional[torch.Tensor] = None,
+        character_image: Optional[torch.Tensor] = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=True, min_length=3)
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         # handle image_luma_ref
         api_image_ref = None
         if image_luma_ref is not None:
-            api_image_ref = await cls._convert_luma_refs(
-                image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
-            )
+            api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
         # handle style_luma_ref
         api_style_ref = None
         if style_image is not None:
-            api_style_ref = await cls._convert_style_image(
-                style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
-            )
+            api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
         # handle character_ref images
         character_ref = None
         if character_image is not None:
-            download_urls = await upload_images_to_comfyapi(
-                character_image, max_images=4, auth_kwargs=auth_kwargs,
-            )
-            character_ref = LumaCharacterRef(
-                identity0=LumaImageIdentity(images=download_urls)
-            )
+            download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
+            character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))

-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/luma/generations/image",
-                method=HttpMethod.POST,
-                request_model=LumaImageGenerationRequest,
-                response_model=LumaGeneration,
-            ),
-            request=LumaImageGenerationRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
+            response_model=LumaGeneration,
+            data=LumaImageGenerationRequest(
                 prompt=prompt,
                 model=model,
                 aspect_ratio=aspect_ratio,
@@ -283,41 +230,21 @@ class LumaImageGenerationNode(IO.ComfyNode):
                 style_ref=api_style_ref,
                 character_ref=character_ref,
             ),
-            auth_kwargs=auth_kwargs,
         )
-        response_api: LumaGeneration = await operation.execute()
-
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/luma/generations/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=LumaGeneration,
-            ),
-            completed_statuses=[LumaState.completed],
-            failed_statuses=[LumaState.failed],
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
+            response_model=LumaGeneration,
             status_extractor=lambda x: x.state,
-            result_url_extractor=image_result_url_extractor,
-            node_id=cls.hidden.unique_id,
-            auth_kwargs=auth_kwargs,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.assets.image) as img_response:
-                img = process_image_response(await img_response.content.read())
-        return IO.NodeOutput(img)
+        return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))

     @classmethod
-    async def _convert_luma_refs(
-        cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
-    ):
+    async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
         luma_urls = []
         ref_count = 0
         for ref in luma_ref.refs:
-            download_urls = await upload_images_to_comfyapi(
-                ref.image, max_images=1, auth_kwargs=auth_kwargs
-            )
+            download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
             luma_urls.append(download_urls[0])
             ref_count += 1
             if ref_count >= max_refs:
@@ -325,27 +252,19 @@ class LumaImageGenerationNode(IO.ComfyNode):
         return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)

     @classmethod
-    async def _convert_style_image(
-        cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
-    ):
-        chain = LumaReferenceChain(
-            first_ref=LumaReference(image=style_image, weight=weight)
-        )
-        return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
+    async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
+        chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
+        return await cls._convert_luma_refs(chain, max_refs=1)


 class LumaImageModifyNode(IO.ComfyNode):
-    """
-    Modifies images synchronously based on prompt and aspect ratio.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaImageModifyNode",
             display_name="Luma Image to Image",
             category="api node/image/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Modifies images synchronously based on prompt and aspect ratio.",
             inputs=[
                 IO.Image.Input(
                     "image",
@@ -395,68 +314,37 @@ class LumaImageModifyNode(IO.ComfyNode):
         image_weight: float,
         seed,
     ) -> IO.NodeOutput:
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        # first, upload image
-        download_urls = await upload_images_to_comfyapi(
-            image, max_images=1, auth_kwargs=auth_kwargs,
-        )
+        download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
         image_url = download_urls[0]
-        # next, make Luma call with download url provided
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/luma/generations/image",
-                method=HttpMethod.POST,
-                request_model=LumaImageGenerationRequest,
-                response_model=LumaGeneration,
-            ),
-            request=LumaImageGenerationRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
+            response_model=LumaGeneration,
+            data=LumaImageGenerationRequest(
                 prompt=prompt,
                 model=model,
                 modify_image_ref=LumaModifyImageRef(
                     url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
                 ),
             ),
-            auth_kwargs=auth_kwargs,
         )
-        response_api: LumaGeneration = await operation.execute()
-
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/luma/generations/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=LumaGeneration,
-            ),
-            completed_statuses=[LumaState.completed],
-            failed_statuses=[LumaState.failed],
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
+            response_model=LumaGeneration,
             status_extractor=lambda x: x.state,
-            result_url_extractor=image_result_url_extractor,
-            node_id=cls.hidden.unique_id,
-            auth_kwargs=auth_kwargs,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.assets.image) as img_response:
-                img = process_image_response(await img_response.content.read())
-        return IO.NodeOutput(img)
+        return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))


 class LumaTextToVideoGenerationNode(IO.ComfyNode):
-    """
-    Generates videos synchronously based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaVideoNode",
             display_name="Luma Text to Video",
             category="api node/video/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos synchronously based on prompt and output_size.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -498,7 +386,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
                     "luma_concepts",
                     tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
                     optional=True,
-                )
+                ),
             ],
             outputs=[IO.Video.Output()],
             hidden=[
@@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
         duration: str,
         loop: bool,
         seed,
-        luma_concepts: LumaConceptChain = None,
+        luma_concepts: Optional[LumaConceptChain] = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False, min_length=3)
         duration = duration if model != LumaVideoModel.ray_1_6 else None
         resolution = resolution if model != LumaVideoModel.ray_1_6 else None

-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/luma/generations",
-                method=HttpMethod.POST,
-                request_model=LumaGenerationRequest,
-                response_model=LumaGeneration,
-            ),
-            request=LumaGenerationRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/luma/generations", method="POST"),
+            response_model=LumaGeneration,
+            data=LumaGenerationRequest(
                 prompt=prompt,
                 model=model,
                 resolution=resolution,
@@ -545,47 +426,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
                 loop=loop,
                 concepts=luma_concepts.create_api_model() if luma_concepts else None,
             ),
-            auth_kwargs=auth_kwargs,
         )
-        response_api: LumaGeneration = await operation.execute()
-
-        if cls.hidden.unique_id:
-            PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
-
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/luma/generations/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=LumaGeneration,
-            ),
-            completed_statuses=[LumaState.completed],
-            failed_statuses=[LumaState.failed],
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
+            response_model=LumaGeneration,
             status_extractor=lambda x: x.state,
-            result_url_extractor=video_result_url_extractor,
-            node_id=cls.hidden.unique_id,
             estimated_duration=LUMA_T2V_AVERAGE_DURATION,
-            auth_kwargs=auth_kwargs,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.assets.video) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))


 class LumaImageToVideoGenerationNode(IO.ComfyNode):
-    """
-    Generates videos synchronously based on prompt, input images, and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaImageToVideoNode",
             display_name="Luma Image to Video",
             category="api node/video/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos synchronously based on prompt, input images, and output_size.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -637,7 +496,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
                     "luma_concepts",
                     tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
                     optional=True,
-                )
+                ),
             ],
             outputs=[IO.Video.Output()],
             hidden=[
@@ -662,25 +521,15 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
         luma_concepts: LumaConceptChain = None,
     ) -> IO.NodeOutput:
         if first_image is None and last_image is None:
-            raise Exception(
-                "At least one of first_image and last_image requires an input."
-            )
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
+            raise Exception("At least one of first_image and last_image requires an input.")
+        keyframes = await cls._convert_to_keyframes(first_image, last_image)
         duration = duration if model != LumaVideoModel.ray_1_6 else None
         resolution = resolution if model != LumaVideoModel.ray_1_6 else None

-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/luma/generations",
-                method=HttpMethod.POST,
-                request_model=LumaGenerationRequest,
-                response_model=LumaGeneration,
-            ),
-            request=LumaGenerationRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/luma/generations", method="POST"),
+            response_model=LumaGeneration,
+            data=LumaGenerationRequest(
                 prompt=prompt,
                 model=model,
                 aspect_ratio=LumaAspectRatio.ratio_16_9,  # ignored, but still needed by the API for some reason
@@ -690,54 +539,31 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
                 keyframes=keyframes,
                 concepts=luma_concepts.create_api_model() if luma_concepts else None,
             ),
-            auth_kwargs=auth_kwargs,
         )
-        response_api: LumaGeneration = await operation.execute()
-
-        if cls.hidden.unique_id:
-            PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
-
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/luma/generations/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=LumaGeneration,
-            ),
-            completed_statuses=[LumaState.completed],
-            failed_statuses=[LumaState.failed],
+        response_poll = await poll_op(
+            cls,
+            poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
+            response_model=LumaGeneration,
             status_extractor=lambda x: x.state,
-            result_url_extractor=video_result_url_extractor,
-            node_id=cls.hidden.unique_id,
             estimated_duration=LUMA_I2V_AVERAGE_DURATION,
-            auth_kwargs=auth_kwargs,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.assets.video) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))

     @classmethod
     async def _convert_to_keyframes(
         cls,
         first_image: torch.Tensor = None,
         last_image: torch.Tensor = None,
-        auth_kwargs: Optional[dict[str,str]] = None,
     ):
         if first_image is None and last_image is None:
             return None
         frame0 = None
         frame1 = None
         if first_image is not None:
-            download_urls = await upload_images_to_comfyapi(
-                first_image, max_images=1, auth_kwargs=auth_kwargs,
-            )
+            download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
             frame0 = LumaImageReference(type="image", url=download_urls[0])
         if last_image is not None:
-            download_urls = await upload_images_to_comfyapi(
-                last_image, max_images=1, auth_kwargs=auth_kwargs,
-            )
+            download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
             frame1 = LumaImageReference(type="image", url=download_urls[0])
         return LumaKeyframes(frame0=frame0, frame1=frame1)
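The recurring change in these API-node files is the same migration: the explicit SynchronousOperation/PollingOperation plumbing (HttpMethod, request_model, completed/failed statuses, auth_kwargs, node_id) collapses into the sync_op and poll_op helpers, which take the node class and an ApiEndpoint and pull auth (and, presumably, progress reporting) from the node itself. A condensed sketch of the new pattern, using the Luma endpoint and models shown above (the surrounding node class and its inputs are omitted):

    response_api = await sync_op(
        cls,
        ApiEndpoint(path="/proxy/luma/generations", method="POST"),
        response_model=LumaGeneration,
        data=LumaGenerationRequest(prompt=prompt, model=model),
    )
    response_poll = await poll_op(
        cls,
        ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
        response_model=LumaGeneration,
        status_extractor=lambda x: x.state,
    )
    return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
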
@ -1,42 +1,33 @@
|
|||||||
from inspect import cleandoc
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import logging
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO
|
|
||||||
from comfy_api.input_impl.video_types import VideoFromFile
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis.minimax_api import (
|
||||||
|
MinimaxFileRetrieveResponse,
|
||||||
|
MiniMaxModel,
|
||||||
|
MinimaxTaskResultResponse,
|
||||||
MinimaxVideoGenerationRequest,
|
MinimaxVideoGenerationRequest,
|
||||||
MinimaxVideoGenerationResponse,
|
MinimaxVideoGenerationResponse,
|
||||||
MinimaxFileRetrieveResponse,
|
|
||||||
MinimaxTaskResultResponse,
|
|
||||||
SubjectReferenceItem,
|
SubjectReferenceItem,
|
||||||
MiniMaxModel,
|
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
download_url_to_video_output,
|
||||||
SynchronousOperation,
|
poll_op,
|
||||||
PollingOperation,
|
sync_op,
|
||||||
EmptyRequest,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
download_url_to_bytesio,
|
|
||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
|
validate_string,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import validate_string
|
|
||||||
from server import PromptServer
|
|
||||||
|
|
||||||
|
|
||||||
I2V_AVERAGE_DURATION = 114
|
I2V_AVERAGE_DURATION = 114
|
||||||
T2V_AVERAGE_DURATION = 234
|
T2V_AVERAGE_DURATION = 234
|
||||||
|
|
||||||
|
|
||||||
async def _generate_mm_video(
|
async def _generate_mm_video(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
*,
|
*,
|
||||||
auth: dict[str, str],
|
|
||||||
node_id: str,
|
|
||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
model: str,
|
model: str,
|
||||||
@ -46,26 +37,21 @@ async def _generate_mm_video(
|
|||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if image is None:
|
if image is None:
|
||||||
validate_string(prompt_text, field_name="prompt_text")
|
validate_string(prompt_text, field_name="prompt_text")
|
||||||
# upload image, if passed in
|
|
||||||
image_url = None
|
image_url = None
|
||||||
if image is not None:
|
if image is not None:
|
||||||
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
|
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
|
||||||
|
|
||||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||||
subject_reference = None
|
subject_reference = None
|
||||||
if subject is not None:
|
if subject is not None:
|
||||||
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
|
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
|
||||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||||
|
|
||||||
|
response = await sync_op(
|
||||||
video_generate_operation = SynchronousOperation(
|
cls,
|
||||||
endpoint=ApiEndpoint(
|
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
|
||||||
path="/proxy/minimax/video_generation",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=MinimaxVideoGenerationRequest,
|
|
||||||
response_model=MinimaxVideoGenerationResponse,
|
response_model=MinimaxVideoGenerationResponse,
|
||||||
),
|
data=MinimaxVideoGenerationRequest(
|
||||||
request=MinimaxVideoGenerationRequest(
|
|
||||||
model=MiniMaxModel(model),
|
model=MiniMaxModel(model),
|
||||||
prompt=prompt_text,
|
prompt=prompt_text,
|
||||||
callback_url=None,
|
callback_url=None,
|
||||||
@ -73,81 +59,50 @@ async def _generate_mm_video(
|
|||||||
subject_reference=subject_reference,
|
subject_reference=subject_reference,
|
||||||
prompt_optimizer=None,
|
prompt_optimizer=None,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response = await video_generate_operation.execute()
|
|
||||||
|
|
||||||
task_id = response.task_id
|
task_id = response.task_id
|
||||||
if not task_id:
|
if not task_id:
|
||||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||||
|
|
||||||
video_generate_operation = PollingOperation(
|
task_result = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/minimax/query/video_generation",
|
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=MinimaxTaskResultResponse,
|
response_model=MinimaxTaskResultResponse,
|
||||||
query_params={"task_id": task_id},
|
|
||||||
),
|
|
||||||
completed_statuses=["Success"],
|
|
||||||
failed_statuses=["Fail"],
|
|
||||||
status_extractor=lambda x: x.status.value,
|
status_extractor=lambda x: x.status.value,
|
||||||
estimated_duration=average_duration,
|
estimated_duration=average_duration,
|
||||||
node_id=node_id,
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
task_result = await video_generate_operation.execute()
|
|
||||||
|
|
||||||
file_id = task_result.file_id
|
file_id = task_result.file_id
|
||||||
if file_id is None:
|
if file_id is None:
|
||||||
raise Exception("Request was not successful. Missing file ID.")
|
raise Exception("Request was not successful. Missing file ID.")
|
||||||
file_retrieve_operation = SynchronousOperation(
|
file_result = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/minimax/files/retrieve",
|
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=MinimaxFileRetrieveResponse,
|
response_model=MinimaxFileRetrieveResponse,
|
||||||
query_params={"file_id": int(file_id)},
|
|
||||||
),
|
|
||||||
request=EmptyRequest(),
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
file_result = await file_retrieve_operation.execute()
|
|
||||||
|
|
||||||
file_url = file_result.file.download_url
|
file_url = file_result.file.download_url
|
||||||
if file_url is None:
|
if file_url is None:
|
||||||
raise Exception(
|
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
|
||||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
if file_result.file.backup_download_url:
|
||||||
|
try:
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
|
||||||
|
except Exception: # if we have a second URL to retrieve the result, try again using that one
|
||||||
|
return IO.NodeOutput(
|
||||||
|
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
|
||||||
)
|
)
|
||||||
logging.info("Generated video URL: %s", file_url)
|
return IO.NodeOutput(await download_url_to_video_output(file_url))
|
||||||
if node_id:
|
|
||||||
if hasattr(file_result.file, "backup_download_url"):
|
|
||||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
|
||||||
else:
|
|
||||||
message = f"Result URL: {file_url}"
|
|
||||||
PromptServer.instance.send_progress_text(message, node_id)
|
|
||||||
|
|
||||||
# Download and return as VideoFromFile
|
|
||||||
video_io = await download_url_to_bytesio(file_url)
|
|
||||||
if video_io is None:
|
|
||||||
error_msg = f"Failed to download video from {file_url}"
|
|
||||||
logging.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
return IO.NodeOutput(VideoFromFile(video_io))
|
|
||||||
|
|
||||||
|
|
||||||
class MinimaxTextToVideoNode(IO.ComfyNode):
|
class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> IO.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="MinimaxTextToVideoNode",
|
node_id="MinimaxTextToVideoNode",
|
||||||
display_name="MiniMax Text to Video",
|
display_name="MiniMax Text to Video",
|
||||||
category="api node/video/MiniMax",
|
category="api node/video/MiniMax",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos synchronously based on a prompt, and optional parameters.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt_text",
|
"prompt_text",
|
||||||
@ -189,11 +144,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
|
|||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
return await _generate_mm_video(
|
return await _generate_mm_video(
|
||||||
auth={
|
cls,
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
prompt_text=prompt_text,
|
prompt_text=prompt_text,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
model=model,
|
model=model,
|
||||||
@ -204,17 +155,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
|
|||||||
|
|
||||||
|
|
||||||
class MinimaxImageToVideoNode(IO.ComfyNode):
|
class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> IO.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="MinimaxImageToVideoNode",
|
node_id="MinimaxImageToVideoNode",
|
||||||
display_name="MiniMax Image to Video",
|
display_name="MiniMax Image to Video",
|
||||||
category="api node/video/MiniMax",
|
category="api node/video/MiniMax",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
@ -261,11 +208,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
|
|||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
return await _generate_mm_video(
|
return await _generate_mm_video(
|
||||||
auth={
|
cls,
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
prompt_text=prompt_text,
|
prompt_text=prompt_text,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -276,17 +219,13 @@ class MinimaxImageToVideoNode(IO.ComfyNode):


 class MinimaxSubjectToVideoNode(IO.ComfyNode):
-    """
-    Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
-    """

     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="MinimaxSubjectToVideoNode",
             display_name="MiniMax Subject to Video",
             category="api node/video/MiniMax",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos synchronously based on an image and prompt, and optional parameters.",
             inputs=[
                 IO.Image.Input(
                     "subject",
@@ -333,11 +272,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
         seed: int = 0,
     ) -> IO.NodeOutput:
         return await _generate_mm_video(
-            auth={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
             prompt_text=prompt_text,
             seed=seed,
             model=model,
@@ -348,15 +283,13 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):


 class MinimaxHailuoVideoNode(IO.ComfyNode):
-    """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""

     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="MinimaxHailuoVideoNode",
             display_name="MiniMax Hailuo Video",
             category="api node/video/MiniMax",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
             inputs=[
                 IO.String.Input(
                     "prompt_text",
@@ -420,10 +353,6 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
         resolution: str = "768P",
         model: str = "MiniMax-Hailuo-02",
     ) -> IO.NodeOutput:
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         if first_frame_image is None:
             validate_string(prompt_text, field_name="prompt_text")

@@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
         # upload image, if passed in
         image_url = None
         if first_frame_image is not None:
-            image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
+            image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]

-        video_generate_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/minimax/video_generation",
-                method=HttpMethod.POST,
-                request_model=MinimaxVideoGenerationRequest,
-                response_model=MinimaxVideoGenerationResponse,
-            ),
-            request=MinimaxVideoGenerationRequest(
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
+            response_model=MinimaxVideoGenerationResponse,
+            data=MinimaxVideoGenerationRequest(
                 model=MiniMaxModel(model),
                 prompt=prompt_text,
                 callback_url=None,
@@ -453,67 +379,42 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
                 duration=duration,
                 resolution=resolution,
             ),
-            auth_kwargs=auth,
         )
-        response = await video_generate_operation.execute()

         task_id = response.task_id
         if not task_id:
             raise Exception(f"MiniMax generation failed: {response.base_resp}")

         average_duration = 120 if resolution == "768P" else 240
-        video_generate_operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path="/proxy/minimax/query/video_generation",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=MinimaxTaskResultResponse,
-                query_params={"task_id": task_id},
-            ),
-            completed_statuses=["Success"],
-            failed_statuses=["Fail"],
+        task_result = await poll_op(
+            cls,
+            ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
+            response_model=MinimaxTaskResultResponse,
             status_extractor=lambda x: x.status.value,
             estimated_duration=average_duration,
-            node_id=cls.hidden.unique_id,
-            auth_kwargs=auth,
         )
-        task_result = await video_generate_operation.execute()

         file_id = task_result.file_id
         if file_id is None:
             raise Exception("Request was not successful. Missing file ID.")
-        file_retrieve_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/minimax/files/retrieve",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=MinimaxFileRetrieveResponse,
-                query_params={"file_id": int(file_id)},
-            ),
-            request=EmptyRequest(),
-            auth_kwargs=auth,
+        file_result = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
+            response_model=MinimaxFileRetrieveResponse,
         )
-        file_result = await file_retrieve_operation.execute()

         file_url = file_result.file.download_url
         if file_url is None:
-            raise Exception(
-                f"No video was found in the response. Full response: {file_result.model_dump()}"
-            )
+            raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
-        logging.info("Generated video URL: %s", file_url)
-        if cls.hidden.unique_id:
-            if hasattr(file_result.file, "backup_download_url"):
-                message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
-            else:
-                message = f"Result URL: {file_url}"
-            PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)

-        video_io = await download_url_to_bytesio(file_url)
-        if video_io is None:
-            error_msg = f"Failed to download video from {file_url}"
-            logging.error(error_msg)
-            raise Exception(error_msg)
-        return IO.NodeOutput(VideoFromFile(video_io))
+        if file_result.file.backup_download_url:
+            try:
+                return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
+            except Exception:  # if we have a second URL to retrieve the result, try again using that one
+                return IO.NodeOutput(
+                    await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
+                )
+        return IO.NodeOutput(await download_url_to_video_output(file_url))


 class MinimaxExtension(ComfyExtension):
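Note on the new download path above: the Hailuo node now tries the primary download URL first, with a short timeout and a couple of retries, and only falls back to backup_download_url if that first attempt raises. A minimal illustrative sketch of the same pattern, where "fetch" is a placeholder for any awaitable downloader (not a function from this diff):

# Sketch only: mirrors the primary-then-backup logic of the hunk above.
async def fetch_with_backup(fetch, primary_url, backup_url=None):
    if backup_url:
        try:
            return await fetch(primary_url, timeout=10, max_retries=2)
        except Exception:
            # a second URL is available, so retry the result through the backup location
            return await fetch(backup_url, max_retries=3)
    return await fetch(primary_url)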
File diff suppressed because it is too large
@@ -78,7 +78,7 @@ class _PollUIState:

 _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
 COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
-FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"]
+FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
 QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]

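The "fail" entry added above matters for the MiniMax refactor: the node no longer passes failed_statuses=["Fail"] explicitly, so the shared poll helper has to recognize MiniMax's "Fail" status on its own. A minimal sketch of the kind of case-insensitive classification this implies (an assumption about the helper, not code taken from the diff):

# Sketch; completed/failed are lists like COMPLETED_STATUSES / FAILED_STATUSES above.
def classify_status(status: str, completed: list, failed: list) -> str:
    s = status.lower()
    if s in completed:
        return "completed"
    if s in failed:
        return "failed"
    return "pending"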
@@ -232,11 +232,12 @@ async def download_url_to_video_output(
     video_url: str,
     *,
     timeout: float = None,
+    max_retries: int = 5,
     cls: type[COMFY_IO.ComfyNode] = None,
 ) -> VideoFromFile:
     """Downloads a video from a URL and returns a `VIDEO` output."""
     result = BytesIO()
-    await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls)
+    await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
     return VideoFromFile(result)

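With the new keyword, callers can bound how many attempts are made per URL. The MiniMax node above uses it roughly like this (sketch of a call site inside an async API node; file_url comes from the task result):

video = await download_url_to_video_output(file_url, timeout=10, max_retries=2)
return IO.NodeOutput(video)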
@@ -445,6 +445,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 resolved_outputs.append(tuple(resolved_output))
             output_data = merge_result_data(resolved_outputs, class_def)
             output_ui = []
+            del pending_subgraph_results[unique_id]
             has_subgraph = False
         else:
             get_progress_state().start_progress(unique_id)
@@ -527,10 +528,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             if new_graph is None:
                 cached_outputs.append((False, node_outputs))
             else:
-                # Check for conflicts
-                for node_id in new_graph.keys():
-                    if dynprompt.has_node(node_id):
-                        raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
                 for node_id, node_info in new_graph.items():
                     new_node_ids.append(node_id)
                     display_id = node_info.get("override_display_id", unique_id)
tests-unit/comfy_quant/test_mixed_precision.py (new file, 232 lines)
@@ -0,0 +1,232 @@
import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

def has_gpu():
    return torch.cuda.is_available()

from comfy.cli_args import args
if not has_gpu():
    args.cpu = True

from comfy import ops
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout


class SimpleModel(torch.nn.Module):
    def __init__(self, operations=ops.disable_weight_init):
        super().__init__()
        self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
        self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
        self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)

    def forward(self, x):
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
        x = self.layer2(x)
        x = torch.nn.functional.relu(x)
        x = self.layer3(x)
        return x


class TestMixedPrecisionOps(unittest.TestCase):

    def test_all_layers_standard(self):
        """Test that model with no quantization works normally"""
        # Configure no quantization
        ops.MixedPrecisionOps._layer_quant_config = {}

        # Create model
        model = SimpleModel(operations=ops.MixedPrecisionOps)

        # Initialize weights manually
        model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
        model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
        model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16))
        model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
        model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
        model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))

        # Initialize weight_function and bias_function
        for layer in [model.layer1, model.layer2, model.layer3]:
            layer.weight_function = []
            layer.bias_function = []

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))
        self.assertEqual(output.dtype, torch.bfloat16)

    def test_mixed_precision_load(self):
        """Test loading a mixed precision model from state dict"""
        # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            },
            "layer3": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict with mixed precision
        fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)

        state_dict = {
            # Layer 1: FP8 E4M3FN
            "layer1.weight": fp8_weight1,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),

            # Layer 2: Standard BF16
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),

            # Layer 3: FP8 E4M3FN
            "layer3.weight": fp8_weight3,
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
            "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
        }

        # Create model and load state dict (strict=False because custom loading pops keys)
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Verify weights are wrapped in QuantizedTensor
        self.assertIsInstance(model.layer1.weight, QuantizedTensor)
        self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)

        # Layer 2 should NOT be quantized
        self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)

        # Layer 3 should be quantized
        self.assertIsInstance(model.layer3.weight, QuantizedTensor)
        self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)

        # Verify scales were loaded
        self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
        self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))

    def test_state_dict_quantized_preserved(self):
        """Test that quantized weights are preserved in state_dict()"""
        # Configure mixed precision
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict1 = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict1, strict=False)

        # Save state dict
        state_dict2 = model.state_dict()

        # Verify layer1.weight is a QuantizedTensor with scale preserved
        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
        self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
        self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)

        # Verify non-quantized layers are standard tensors
        self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
        self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)

    def test_weight_function_compatibility(self):
        """Test that weight_function (LoRA) works with quantized layers"""
        # Configure FP8 quantization
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Add a weight function (simulating LoRA)
        # This should trigger dequantization during forward pass
        def apply_lora(weight):
            lora_delta = torch.randn_like(weight) * 0.01
            return weight + lora_delta

        model.layer1.weight_function.append(apply_lora)

        # Forward pass should work with LoRA (triggers weight_function path)
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))

    def test_error_handling_unknown_format(self):
        """Test that unknown formats raise error"""
        # Configure with unknown format
        layer_quant_config = {
            "layer1": {
                "format": "unknown_format_xyz",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict
        state_dict = {
            "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        with self.assertRaises(KeyError):
            model.load_state_dict(state_dict, strict=False)


if __name__ == "__main__":
    unittest.main()
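Both new test modules force args.cpu when no GPU is present and end with unittest.main(), so they can be run directly; assuming the repository's usual pytest-based unit-test setup, the second form below should also work:

python tests-unit/comfy_quant/test_mixed_precision.py
pytest tests-unit/comfy_quant -q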
tests-unit/comfy_quant/test_quant_registry.py (new file, 190 lines)
@@ -0,0 +1,190 @@
import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

def has_gpu():
    return torch.cuda.is_available()

from comfy.cli_args import args
if not has_gpu():
    args.cpu = True

from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout


class TestQuantizedTensor(unittest.TestCase):
    """Test the QuantizedTensor subclass with FP8 layout"""

    def test_creation(self):
        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, TensorCoreFP8Layout)

    def test_dequantize(self):
        """Test explicit dequantization"""

        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
        dequantized = qt.dequantize()

        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qt = QuantizedTensor.from_float(
            float_tensor,
            TensorCoreFP8Layout,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt.shape, (64, 32))

        # Verify dequantization gives approximately original values
        dequantized = qt.dequantize()
        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
        self.assertLess(mean_rel_error, 0.1)


class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations"""

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()

        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()

        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)

        # Verify it's a deep copy
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    @unittest.skipUnless(has_gpu(), "GPU not available")
    def test_to_device(self):
        """Test device transfer"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')

        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')


class TestTensorCoreFP8Layout(unittest.TestCase):
    """Test the TensorCoreFP8Layout implementation"""

    def test_quantize(self):
        """Test quantization method"""
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
        self.assertEqual(qdata.shape, float_tensor.shape)
        self.assertIn('scale', layout_params)
        self.assertIn('orig_dtype', layout_params)
        self.assertEqual(layout_params['orig_dtype'], torch.float32)

    def test_dequantize(self):
        """Test dequantization method"""
        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
        scale = torch.tensor(1.0)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)

        # Should approximately match original
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_q = QuantizedTensor.from_float(
            a_fp32,
            TensorCoreFP8Layout,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, QuantizedTensor)
        expected = torch.abs(a_fp32)
        # FP8 introduces quantization error, so use loose tolerance
        mean_error = (result - expected).abs().mean()
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")


if __name__ == "__main__":
    unittest.main()
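Distilled from the registry tests above into a minimal round trip (same public APIs as in the tests; tensor size and scale are illustrative):

import torch
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

w = torch.randn(8, 16, dtype=torch.float32)
qt = QuantizedTensor.from_float(w, TensorCoreFP8Layout, scale=torch.tensor(1.5), dtype=torch.float8_e4m3fn)
assert qt.dtype == torch.float8_e4m3fn and qt.shape == (8, 16)
# dequantize() recovers the original values up to FP8 rounding error
rel_err = ((qt.dequantize() - w).abs() / (w.abs() + 1e-6)).mean()
assert rel_err < 0.1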