Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-18 18:43:05 +08:00

Commit d8b821e47b: Merge branch 'master' into dr-support-pip-cm
@@ -153,7 +153,7 @@ class PerformanceFeature(enum.Enum):
     AutoTune = "autotune"
     PinnedMem = "pinned_memory"
 
-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
 
 parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
 parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

@@ -298,6 +298,7 @@ class ModelPatcher:
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
         n.parent = self
+        n.pinned = self.pinned
 
         n.force_cast_weights = self.force_cast_weights
 

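The clone above now carries `pinned` across. A minimal sketch of why a shallow clone should share that bookkeeping (the class below is a stand-in, not ComfyUI's ModelPatcher):

class Patcher:
    def __init__(self):
        self.backup = {}
        self.object_patches_backup = {}
        self.parent = None
        self.pinned = set()  # tensors currently pinned in host memory

    def clone(self):
        n = Patcher()
        n.backup = self.backup                    # shared with the original
        n.object_patches_backup = self.object_patches_backup
        n.parent = self
        n.pinned = self.pinned                    # the attribute this commit adds
        return n

p = Patcher()
p.pinned.add("layer1.weight")
assert p.clone().pinned is p.pinned  # clones see the same pinned set
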
comfy/ops.py (47 lines changed)
@@ -84,7 +84,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if device is None:
         device = input.device
 
-    if offloadable:
+    if offloadable and (device != s.weight.device or
+                        (s.bias is not None and device != s.bias.device)):
         offload_stream = comfy.model_management.get_offload_stream(device)
     else:
         offload_stream = None

@@ -94,20 +95,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     else:
         wf_context = contextlib.nullcontext()
 
-    bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
-    if s.bias is not None:
-        has_function = len(s.bias_function) > 0
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
+
+    weight_has_function = len(s.weight_function) > 0
+    bias_has_function = len(s.bias_function) > 0
+
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+
+    bias = None
+    if s.bias is not None:
+        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
 
-        if has_function:
+        if bias_has_function:
             with wf_context:
                 for f in s.bias_function:
                     bias = f(bias)
 
-    has_function = len(s.weight_function) > 0
-    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
-    if has_function:
+    weight = weight.to(dtype=dtype)
+    if weight_has_function:
         with wf_context:
             for f in s.weight_function:
                 weight = f(weight)

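Taken together, the two cast_bias_weight hunks change the flow: the offload stream is requested only when a tensor actually has to move devices, the weight is first moved at its storage dtype, the patch functions run, and only then is the weight converted to the compute dtype. A standalone sketch of that ordering (an approximation, not comfy.ops itself):

import torch

def cast_weight_sketch(weight, weight_functions, device, dtype):
    # Mirrors the new guard: a stream is only needed for a real device move.
    needs_stream = torch.device(device) != weight.device
    w = weight.to(device=device)   # moved at its storage dtype
    for f in weight_functions:     # patches see the original dtype
        w = f(w)
    return w.to(dtype=dtype), needs_stream  # dtype cast happens last

out, needs_stream = cast_weight_sketch(torch.randn(4, 4), [lambda t: t * 2], "cpu", torch.float16)
assert out.dtype == torch.float16 and needs_stream is False
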
@@ -401,15 +406,9 @@ def fp8_linear(self, input):
     if dtype not in [torch.float8_e4m3fn]:
         return None
 
-    tensor_2d = False
-    if len(input.shape) == 2:
-        tensor_2d = True
-        input = input.unsqueeze(1)
-
-    input_shape = input.shape
     input_dtype = input.dtype
 
-    if len(input.shape) == 3:
+    if input.ndim == 3 or input.ndim == 2:
         w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
 
         scale_weight = self.scale_weight

@@ -422,24 +421,20 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
             input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
             layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-            quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
+            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
         else:
             scale_input = scale_input.to(input.device)
-            quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
+            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
 
         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
         layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
-        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
         uncast_bias_weight(self, w, bias, offload_stream)
-
-        if tensor_2d:
-            return o.reshape(input_shape[0], -1)
-        return o.reshape((-1, input_shape[1], self.weight.shape[0]))
+        return o
 
     return None

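The deleted unsqueeze/reshape bookkeeping is safe to drop because torch.nn.functional.linear already handles both 2D and 3D inputs once the quantized path stops flattening the input itself:

import torch
import torch.nn.functional as F

w = torch.randn(8, 16)
assert F.linear(torch.randn(4, 16), w).shape == (4, 8)        # 2D input
assert F.linear(torch.randn(2, 4, 16), w).shape == (2, 4, 8)  # 3D input
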
@@ -540,12 +535,12 @@ if CUBLAS_IS_AVAILABLE:
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
-from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
+from .quant_ops import QuantizedTensor
 
 QUANT_FORMAT_MIXINS = {
     "float8_e4m3fn": {
         "dtype": torch.float8_e4m3fn,
-        "layout_type": TensorCoreFP8Layout,
+        "layout_type": "TensorCoreFP8Layout",
         "parameters": {
             "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
             "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),

@@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor):
             layout_type: Layout class (subclass of QuantizedLayout)
             layout_params: Dict with layout-specific parameters
         """
-        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
+        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
 
     def __init__(self, qdata, layout_type, layout_params):
         self._qdata = qdata.contiguous()

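_make_subclass aliases the payload's storage, while _make_wrapper_subclass creates a metadata-only shell (shape, dtype, device) whose real data lives in an attribute, so every aten op is routed through __torch_dispatch__. A simplified version of the pattern, not comfy's full class:

import torch

class Wrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, payload):
        return torch.Tensor._make_wrapper_subclass(
            cls, payload.shape, device=payload.device,
            dtype=payload.dtype, requires_grad=False)

    def __init__(self, payload):
        self._payload = payload  # the actual data; the wrapper holds metadata

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # unwrap, run the real op, re-wrap
        unwrapped = [a._payload if isinstance(a, Wrapper) else a for a in args]
        return Wrapper(func(*unwrapped, **(kwargs or {})))

t = Wrapper(torch.ones(3))
assert isinstance(t.detach(), Wrapper)  # detach was intercepted and re-wrapped
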
@@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor):
 
     @classmethod
     def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
-        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
+        qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
         return cls(qdata, layout_type, layout_params)
 
     def dequantize(self) -> torch.Tensor:
-        return self._layout_type.dequantize(self._qdata, **self._layout_params)
+        return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

@@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout):
         return qtensor._qdata, qtensor._layout_params['scale']
 
 
-@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+LAYOUTS = {
+    "TensorCoreFP8Layout": TensorCoreFP8Layout,
+}
+
+
+@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
 def fp8_linear(func, args, kwargs):
     input_tensor = args[0]
     weight = args[1]

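Referencing layouts by name through the LAYOUTS dict (rather than by class object) keeps layout_type plain serializable data on the tensor. A sketch of the lookup, with stand-in classes:

class TensorCoreFP8Layout:
    @staticmethod
    def quantize(tensor, **kwargs): ...

LAYOUTS = {"TensorCoreFP8Layout": TensorCoreFP8Layout}

def resolve(layout_type: str):
    # from_float/dequantize now go through this indirection
    return LAYOUTS[layout_type]

assert resolve("TensorCoreFP8Layout") is TensorCoreFP8Layout
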
@@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs):
             'scale': output_scale,
             'orig_dtype': input_tensor._layout_params['orig_dtype']
         }
-        return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
+        return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
     else:
         return output
 

@@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs):
         input_tensor = input_tensor.dequantize()
 
     return torch.nn.functional.linear(input_tensor, weight, bias)
+
+
+@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
+@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
+def fp8_func(func, args, kwargs):
+    input_tensor = args[0]
+    if isinstance(input_tensor, QuantizedTensor):
+        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+        ar = list(args)
+        ar[0] = plain_input
+        return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
+    return func(*args, **kwargs)

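fp8_func follows the usual unwrap/re-wrap recipe: pull the plain fp8 payload out of args[0], run the original view/transpose op on it, and wrap the result with the same layout params. A self-contained imitation of that recipe (dict stand-ins instead of the comfy classes):

import torch

def passthrough_op(func, args, layout_params):
    payload, *rest = args
    return {"qdata": func(payload, *rest), "params": layout_params}  # re-wrapped

q = {"qdata": torch.randn(4, 8).to(torch.float8_e4m3fn), "params": {"scale": 2.0}}
t = passthrough_op(torch.ops.aten.t.default, (q["qdata"],), q["params"])
assert t["qdata"].shape == (8, 4) and t["params"]["scale"] == 2.0
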
@@ -46,7 +46,7 @@ class TextToVideoNode(IO.ComfyNode):
                     multiline=True,
                     default="",
                 ),
-                IO.Combo.Input("duration", options=[6, 8, 10], default=8),
+                IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
                 IO.Combo.Input(
                     "resolution",
                     options=[

@@ -85,6 +85,10 @@ class TextToVideoNode(IO.ComfyNode):
         generate_audio: bool = False,
     ) -> IO.NodeOutput:
         validate_string(prompt, min_length=1, max_length=10000)
+        if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
+            raise ValueError(
+                "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
+            )
         response = await sync_op_raw(
             cls,
             ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),

@@ -118,7 +122,7 @@ class ImageToVideoNode(IO.ComfyNode):
                     multiline=True,
                     default="",
                 ),
-                IO.Combo.Input("duration", options=[6, 8, 10], default=8),
+                IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
                 IO.Combo.Input(
                     "resolution",
                     options=[

@@ -158,6 +162,10 @@ class ImageToVideoNode(IO.ComfyNode):
         generate_audio: bool = False,
     ) -> IO.NodeOutput:
         validate_string(prompt, min_length=1, max_length=10000)
+        if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
+            raise ValueError(
+                "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
+            )
         if get_number_of_images(image) != 1:
             raise ValueError("Currently only one input image is supported.")
         response = await sync_op_raw(

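The guard is duplicated verbatim in TextToVideoNode and ImageToVideoNode; a shared helper (hypothetical, not part of this commit) would keep the two in sync:

def validate_long_duration(duration: int, model: str, resolution: str, fps: int) -> None:
    # Hypothetical helper collecting the duplicated LTX duration check.
    if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
        raise ValueError(
            "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
        )
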
@@ -20,13 +20,6 @@ from comfy_api_nodes.apis.stability_api import (
     StabilityAudioInpaintRequest,
     StabilityAudioResponse,
 )
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
 from comfy_api_nodes.util import (
     validate_audio_duration,
     validate_string,

@@ -34,6 +27,9 @@ from comfy_api_nodes.util import (
     bytesio_to_image_tensor,
     tensor_to_bytesio,
     audio_bytes_to_audio_input,
+    sync_op,
+    poll_op,
+    ApiEndpoint,
 )
 
 import torch

@@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
                 "image": image_binary
             }
 
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/generate/ultra",
-                method=HttpMethod.POST,
-                request_model=StabilityStableUltraRequest,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
             response_model=StabilityStableUltraResponse,
-            ),
-            request=StabilityStableUltraRequest(
+            data=StabilityStableUltraRequest(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 aspect_ratio=aspect_ratio,

@@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
             ),
             files=files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
 
         if response_api.finish_reason != "SUCCESS":
             raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")

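The same mechanical rewrite repeats for every Stability node below. Its skeleton, reduced to comments (signatures inferred from this diff, so treat them as approximate):

# before:
#   auth = {"auth_token": ..., "comfy_api_key": ...}
#   operation = SynchronousOperation(
#       endpoint=ApiEndpoint(path=..., method=HttpMethod.POST,
#                            request_model=..., response_model=...),
#       request=..., files=..., content_type=..., auth_kwargs=auth)
#   response_api = await operation.execute()
#
# after:
#   response_api = await sync_op(
#       cls, ApiEndpoint(path=..., method="POST"),
#       response_model=..., data=..., files=..., content_type=...)
#
# auth_kwargs disappears because sync_op derives credentials from the node class (cls).
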
@@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
                 "image": image_binary
             }
 
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/generate/sd3",
-                method=HttpMethod.POST,
-                request_model=StabilityStable3_5Request,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
             response_model=StabilityStableUltraResponse,
-            ),
-            request=StabilityStable3_5Request(
+            data=StabilityStable3_5Request(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 aspect_ratio=aspect_ratio,

@@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
             ),
             files=files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
 
         if response_api.finish_reason != "SUCCESS":
             raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")

@@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
                 "image": image_binary
             }
 
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/upscale/conservative",
-                method=HttpMethod.POST,
-                request_model=StabilityUpscaleConservativeRequest,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
             response_model=StabilityStableUltraResponse,
-            ),
-            request=StabilityUpscaleConservativeRequest(
+            data=StabilityUpscaleConservativeRequest(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 creativity=round(creativity,2),

@@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
             ),
             files=files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
 
         if response_api.finish_reason != "SUCCESS":
             raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")

@@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
                 "image": image_binary
             }
 
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/upscale/creative",
-                method=HttpMethod.POST,
-                request_model=StabilityUpscaleCreativeRequest,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
             response_model=StabilityAsyncResponse,
-            ),
-            request=StabilityUpscaleCreativeRequest(
+            data=StabilityUpscaleCreativeRequest(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 creativity=round(creativity,2),

@@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
             ),
             files=files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
 
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/stability/v2beta/results/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
             response_model=StabilityResultsGetResponse,
-            ),
             poll_interval=3,
-            completed_statuses=[StabilityPollStatus.finished],
-            failed_statuses=[StabilityPollStatus.failed],
             status_extractor=lambda x: get_async_dummy_status(x),
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
         )
-        response_poll: StabilityResultsGetResponse = await operation.execute()
 
         if response_poll.finish_reason != "SUCCESS":
             raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")

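poll_op can drop the explicit completed_statuses/failed_statuses arguments because the shared defaults cover Stability's states once COMPLETED_STATUSES gains "finished" (see the client hunk further down). A sketch of that classification, with a hypothetical classify helper and the lists copied from this diff:

COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"]
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]

def classify(status: str) -> str:
    s = status.lower()
    if s in COMPLETED_STATUSES:
        return "completed"
    return "failed" if s in FAILED_STATUSES else "pending"

assert classify("finished") == "completed"  # Stability's terminal state now matches
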
@@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
                 "image": image_binary
             }
 
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/upscale/fast",
-                method=HttpMethod.POST,
-                request_model=EmptyRequest,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
             response_model=StabilityStableUltraResponse,
-            ),
-            request=EmptyRequest(),
             files=files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
 
         if response_api.finish_reason != "SUCCESS":
             raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")

@@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode):
     async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
         validate_string(prompt, max_length=10000)
         payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
-                method=HttpMethod.POST,
-                request_model=StabilityTextToAudioRequest,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
             response_model=StabilityAudioResponse,
-            ),
-            request=payload,
+            data=payload,
             content_type="multipart/form-data",
-            auth_kwargs= {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
         )
-        response_api = await operation.execute()
         if not response_api.audio:
             raise ValueError("No audio file was received in response.")
         return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))

@@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode):
         payload = StabilityAudioToAudioRequest(
             prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
         )
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
-                method=HttpMethod.POST,
-                request_model=StabilityAudioToAudioRequest,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
             response_model=StabilityAudioResponse,
-            ),
-            request=payload,
+            data=payload,
             content_type="multipart/form-data",
             files={"audio": audio_input_to_mp3(audio)},
-            auth_kwargs= {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
         )
-        response_api = await operation.execute()
         if not response_api.audio:
             raise ValueError("No audio file was received in response.")
         return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))

@@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode):
             mask_start=mask_start,
             mask_end=mask_end,
         )
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
-                method=HttpMethod.POST,
-                request_model=StabilityAudioInpaintRequest,
+        response_api = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
             response_model=StabilityAudioResponse,
-            ),
-            request=payload,
+            data=payload,
             content_type="multipart/form-data",
             files={"audio": audio_input_to_mp3(audio)},
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
         )
-        response_api = await operation.execute()
         if not response_api.audio:
             raise ValueError("No audio file was received in response.")
         return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))

@@ -77,7 +77,7 @@ class _PollUIState:
 
 
 _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
-COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
+COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"]
 FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
 QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
 

@@ -589,7 +589,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
         operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
         logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
 
-        payload_headers = {"Accept": "*/*"}
+        payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
         if not parsed_url.scheme and not parsed_url.netloc:  # is URL relative?
             payload_headers.update(get_auth_header(cfg.node_cls))
         if cfg.endpoint.headers:

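The one-line header change, isolated: JSON endpoints now advertise "application/json" while binary downloads keep the wildcard:

def accept_header(expect_binary: bool) -> dict:
    return {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}

assert accept_header(False) == {"Accept": "application/json"}
assert accept_header(True) == {"Accept": "*/*"}
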
@@ -14,7 +14,7 @@ if not has_gpu():
     args.cpu = True
 
 from comfy import ops
-from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
+from comfy.quant_ops import QuantizedTensor
 
 
 class SimpleModel(torch.nn.Module):

@@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase):
 
         # Verify weights are wrapped in QuantizedTensor
         self.assertIsInstance(model.layer1.weight, QuantizedTensor)
-        self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
 
         # Layer 2 should NOT be quantized
         self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
 
         # Layer 3 should be quantized
         self.assertIsInstance(model.layer3.weight, QuantizedTensor)
-        self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
 
         # Verify scales were loaded
         self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)

@@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         # Verify layer1.weight is a QuantizedTensor with scale preserved
         self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
         self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
-        self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
 
         # Verify non-quantized layers are standard tensors
         self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)

@@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase):
         scale = torch.tensor(2.0)
         layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
 
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
 
         self.assertIsInstance(qt, QuantizedTensor)
         self.assertEqual(qt.shape, (256, 128))
         self.assertEqual(qt.dtype, torch.float8_e4m3fn)
         self.assertEqual(qt._layout_params['scale'], scale)
         self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
-        self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
 
     def test_dequantize(self):
         """Test explicit dequantization"""

@@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase):
         scale = torch.tensor(3.0)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
 
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
         dequantized = qt.dequantize()
 
         self.assertEqual(dequantized.dtype, torch.float32)

@@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase):
 
         qt = QuantizedTensor.from_float(
             float_tensor,
-            TensorCoreFP8Layout,
+            "TensorCoreFP8Layout",
             scale=scale,
             dtype=torch.float8_e4m3fn
         )

@@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase):
         fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
         scale = torch.tensor(1.5)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
 
         # Detach should return a new QuantizedTensor
         qt_detached = qt.detach()
 
         self.assertIsInstance(qt_detached, QuantizedTensor)
         self.assertEqual(qt_detached.shape, qt.shape)
-        self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
 
     def test_clone(self):
         """Test clone operation on quantized tensor"""
         fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
         scale = torch.tensor(1.5)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
 
         # Clone should return a new QuantizedTensor
         qt_cloned = qt.clone()
 
         self.assertIsInstance(qt_cloned, QuantizedTensor)
         self.assertEqual(qt_cloned.shape, qt.shape)
-        self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
+        self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
 
         # Verify it's a deep copy
         self.assertIsNot(qt_cloned._qdata, qt._qdata)

@@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase):
         fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
         scale = torch.tensor(1.5)
         layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
 
         # Moving to same device should work (CPU to CPU)
         qt_cpu = qt.to('cpu')

@@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase):
         scale = torch.tensor(1.0)
         a_q = QuantizedTensor.from_float(
             a_fp32,
-            TensorCoreFP8Layout,
+            "TensorCoreFP8Layout",
             scale=scale,
             dtype=torch.float8_e4m3fn
         )