mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00
159 lines
3.7 KiB
Python
159 lines
3.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
"""
|
|
dtype definitions and utility functions of AITemplate
|
|
"""
|
|
|
|
|
|
_DTYPE2BYTE = {
|
|
"bool": 1,
|
|
"float16": 2,
|
|
"float32": 4,
|
|
"float": 4,
|
|
"int": 4,
|
|
"int32": 4,
|
|
"int64": 8,
|
|
"bfloat16": 2,
|
|
}
|
|
|
|
|
|
# Maps dtype strings to AITemplateDtype enum in model_interface.h.
|
|
# Must be kept in sync!
|
|
# We can consider defining an AITemplateDtype enum to use on the Python
|
|
# side at some point, but stick to strings for now to keep things consistent
|
|
# with other Python APIs.
|
|
_DTYPE_TO_ENUM = {
|
|
"float16": 1,
|
|
"float32": 2,
|
|
"float": 2,
|
|
"int": 3,
|
|
"int32": 3,
|
|
"int64": 4,
|
|
"bool": 5,
|
|
"bfloat16": 6,
|
|
}
|
|
|
|
|
|
def get_dtype_size(dtype: str) -> int:
|
|
"""Returns size (in bytes) of the given dtype str.
|
|
|
|
Parameters
|
|
----------
|
|
dtype: str
|
|
A data type string.
|
|
|
|
Returns
|
|
----------
|
|
int
|
|
Size (in bytes) of this dtype.
|
|
"""
|
|
|
|
if dtype not in _DTYPE2BYTE:
|
|
raise KeyError(f"Unknown dtype: {dtype}. Expected one of {_DTYPE2BYTE.keys()}")
|
|
return _DTYPE2BYTE[dtype]
|
|
|
|
|
|
def normalize_dtype(dtype: str) -> str:
|
|
"""Returns a normalized dtype str.
|
|
|
|
Parameters
|
|
----------
|
|
dtype: str
|
|
A data type string.
|
|
|
|
Returns
|
|
----------
|
|
str
|
|
normalized dtype str.
|
|
"""
|
|
if dtype == "int":
|
|
return "int32"
|
|
if dtype == "float":
|
|
return "float32"
|
|
return dtype
|
|
|
|
|
|
def dtype_str_to_enum(dtype: str) -> int:
|
|
"""Returns the AITemplateDtype enum value (defined in model_interface.h) of
|
|
the given dtype str.
|
|
|
|
Parameters
|
|
----------
|
|
dtype: str
|
|
A data type string.
|
|
|
|
Returns
|
|
----------
|
|
int
|
|
the AITemplateDtype enum value.
|
|
"""
|
|
if dtype not in _DTYPE_TO_ENUM:
|
|
raise ValueError(
|
|
f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPE_TO_ENUM.keys())}"
|
|
)
|
|
return _DTYPE_TO_ENUM[dtype]
|
|
|
|
|
|
def dtype_to_enumerator(dtype: str) -> str:
|
|
"""Returns the string representation of the AITemplateDtype enum
|
|
(defined in model_interface.h) for the given dtype str.
|
|
|
|
Parameters
|
|
----------
|
|
dtype: str
|
|
A data type string.
|
|
|
|
Returns
|
|
----------
|
|
str
|
|
the AITemplateDtype enum string representation.
|
|
"""
|
|
|
|
def _impl(dtype):
|
|
if dtype == "float16":
|
|
return "kHalf"
|
|
elif dtype == "float32" or dtype == "float":
|
|
return "kFloat"
|
|
elif dtype == "int32" or dtype == "int":
|
|
return "kInt"
|
|
elif dtype == "int64":
|
|
return "kLong"
|
|
elif dtype == "bool":
|
|
return "kBool"
|
|
elif dtype == "bfloat16":
|
|
return "kBFloat16"
|
|
else:
|
|
raise AssertionError(f"unknown dtype {dtype}")
|
|
|
|
return f"AITemplateDtype::{_impl(dtype)}"
|
|
|
|
|
|
def is_same_dtype(dtype1: str, dtype2: str) -> bool:
|
|
"""Returns True if dtype1 and dtype2 are the same dtype and False otherwise.
|
|
|
|
Parameters
|
|
----------
|
|
dtype1: str
|
|
A data type string.
|
|
dtype2: str
|
|
A data type string.
|
|
|
|
Returns
|
|
----------
|
|
bool
|
|
whether dtype1 and dtype2 are the same dtype
|
|
"""
|
|
return normalize_dtype(dtype1) == normalize_dtype(dtype2)
|