mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 08:52:34 +08:00
include aitemplate Model
This commit is contained in:
parent
b32c2eaafd
commit
eccba18c17
158
comfy/aitemplate/dtype.py
Normal file
158
comfy/aitemplate/dtype.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
dtype definitions and utility functions of AITemplate
|
||||
"""
|
||||
|
||||
|
||||
_DTYPE2BYTE = {
|
||||
"bool": 1,
|
||||
"float16": 2,
|
||||
"float32": 4,
|
||||
"float": 4,
|
||||
"int": 4,
|
||||
"int32": 4,
|
||||
"int64": 8,
|
||||
"bfloat16": 2,
|
||||
}
|
||||
|
||||
|
||||
# Maps dtype strings to AITemplateDtype enum in model_interface.h.
|
||||
# Must be kept in sync!
|
||||
# We can consider defining an AITemplateDtype enum to use on the Python
|
||||
# side at some point, but stick to strings for now to keep things consistent
|
||||
# with other Python APIs.
|
||||
_DTYPE_TO_ENUM = {
|
||||
"float16": 1,
|
||||
"float32": 2,
|
||||
"float": 2,
|
||||
"int": 3,
|
||||
"int32": 3,
|
||||
"int64": 4,
|
||||
"bool": 5,
|
||||
"bfloat16": 6,
|
||||
}
|
||||
|
||||
|
||||
def get_dtype_size(dtype: str) -> int:
|
||||
"""Returns size (in bytes) of the given dtype str.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dtype: str
|
||||
A data type string.
|
||||
|
||||
Returns
|
||||
----------
|
||||
int
|
||||
Size (in bytes) of this dtype.
|
||||
"""
|
||||
|
||||
if dtype not in _DTYPE2BYTE:
|
||||
raise KeyError(f"Unknown dtype: {dtype}. Expected one of {_DTYPE2BYTE.keys()}")
|
||||
return _DTYPE2BYTE[dtype]
|
||||
|
||||
|
||||
def normalize_dtype(dtype: str) -> str:
|
||||
"""Returns a normalized dtype str.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dtype: str
|
||||
A data type string.
|
||||
|
||||
Returns
|
||||
----------
|
||||
str
|
||||
normalized dtype str.
|
||||
"""
|
||||
if dtype == "int":
|
||||
return "int32"
|
||||
if dtype == "float":
|
||||
return "float32"
|
||||
return dtype
|
||||
|
||||
|
||||
def dtype_str_to_enum(dtype: str) -> int:
|
||||
"""Returns the AITemplateDtype enum value (defined in model_interface.h) of
|
||||
the given dtype str.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dtype: str
|
||||
A data type string.
|
||||
|
||||
Returns
|
||||
----------
|
||||
int
|
||||
the AITemplateDtype enum value.
|
||||
"""
|
||||
if dtype not in _DTYPE_TO_ENUM:
|
||||
raise ValueError(
|
||||
f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPE_TO_ENUM.keys())}"
|
||||
)
|
||||
return _DTYPE_TO_ENUM[dtype]
|
||||
|
||||
|
||||
def dtype_to_enumerator(dtype: str) -> str:
|
||||
"""Returns the string representation of the AITemplateDtype enum
|
||||
(defined in model_interface.h) for the given dtype str.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dtype: str
|
||||
A data type string.
|
||||
|
||||
Returns
|
||||
----------
|
||||
str
|
||||
the AITemplateDtype enum string representation.
|
||||
"""
|
||||
|
||||
def _impl(dtype):
|
||||
if dtype == "float16":
|
||||
return "kHalf"
|
||||
elif dtype == "float32" or dtype == "float":
|
||||
return "kFloat"
|
||||
elif dtype == "int32" or dtype == "int":
|
||||
return "kInt"
|
||||
elif dtype == "int64":
|
||||
return "kLong"
|
||||
elif dtype == "bool":
|
||||
return "kBool"
|
||||
elif dtype == "bfloat16":
|
||||
return "kBFloat16"
|
||||
else:
|
||||
raise AssertionError(f"unknown dtype {dtype}")
|
||||
|
||||
return f"AITemplateDtype::{_impl(dtype)}"
|
||||
|
||||
|
||||
def is_same_dtype(dtype1: str, dtype2: str) -> bool:
|
||||
"""Returns True if dtype1 and dtype2 are the same dtype and False otherwise.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dtype1: str
|
||||
A data type string.
|
||||
dtype2: str
|
||||
A data type string.
|
||||
|
||||
Returns
|
||||
----------
|
||||
bool
|
||||
whether dtype1 and dtype2 are the same dtype
|
||||
"""
|
||||
return normalize_dtype(dtype1) == normalize_dtype(dtype2)
|
||||
97
comfy/aitemplate/misc.py
Normal file
97
comfy/aitemplate/misc.py
Normal file
@ -0,0 +1,97 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
miscellaneous utilities
|
||||
"""
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
|
||||
|
||||
def is_debug():
|
||||
logger = logging.getLogger("aitemplate")
|
||||
return logger.level == logging.DEBUG
|
||||
|
||||
|
||||
def is_linux() -> bool:
|
||||
return platform.system() == "Linux"
|
||||
|
||||
|
||||
def is_windows() -> bool:
|
||||
return os.name == "nt"
|
||||
|
||||
|
||||
def setup_logger(name):
|
||||
root_logger = logging.getLogger(name)
|
||||
info_handle = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s %(levelname)s <%(name)s> %(message)s")
|
||||
info_handle.setFormatter(formatter)
|
||||
root_logger.addHandler(info_handle)
|
||||
root_logger.propagate = False
|
||||
|
||||
DEFAULT_LOGLEVEL = logging.getLogger().level
|
||||
log_level_str = os.environ.get("LOGLEVEL", None)
|
||||
LOG_LEVEL = (
|
||||
getattr(logging, log_level_str.upper())
|
||||
if log_level_str is not None
|
||||
else DEFAULT_LOGLEVEL
|
||||
)
|
||||
root_logger.setLevel(LOG_LEVEL)
|
||||
return root_logger
|
||||
|
||||
|
||||
def short_str(s, length=8) -> str:
|
||||
"""
|
||||
Returns a hashed string, somewhat similar to URL shortener.
|
||||
"""
|
||||
hash_str = hashlib.sha256(s.encode()).hexdigest()
|
||||
return hash_str[0:length]
|
||||
|
||||
|
||||
def callstack_stats(enable=False):
|
||||
if enable:
|
||||
|
||||
def decorator(f):
|
||||
import cProfile
|
||||
import io
|
||||
import pstats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def inner_function(*args, **kwargs):
|
||||
pr = cProfile.Profile()
|
||||
pr.enable()
|
||||
result = f(*args, **kwargs)
|
||||
pr.disable()
|
||||
s = io.StringIO()
|
||||
pstats.Stats(pr, stream=s).sort_stats(
|
||||
pstats.SortKey.CUMULATIVE
|
||||
).print_stats(30)
|
||||
logger.debug(s.getvalue())
|
||||
return result
|
||||
|
||||
return inner_function
|
||||
|
||||
return decorator
|
||||
else:
|
||||
|
||||
def decorator(f):
|
||||
def inner_function(*args, **kwargs):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return inner_function
|
||||
|
||||
return decorator
|
||||
1065
comfy/aitemplate/model.py
Normal file
1065
comfy/aitemplate/model.py
Normal file
File diff suppressed because it is too large
Load Diff
58
comfy/aitemplate/torch_utils.py
Normal file
58
comfy/aitemplate/torch_utils.py
Normal file
@ -0,0 +1,58 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Functions for working with torch Tensors.
|
||||
AITemplate doesn't depend on PyTorch, but it exposes
|
||||
many APIs that work with torch Tensors anyways.
|
||||
|
||||
The functions in this file may assume that
|
||||
`import torch` will work.
|
||||
"""
|
||||
|
||||
|
||||
def types_mapping():
|
||||
from torch import bfloat16, bool, float16, float32, int32, int64
|
||||
|
||||
yield (float16, "float16")
|
||||
yield (bfloat16, "bfloat16")
|
||||
yield (float32, "float32")
|
||||
yield (int32, "int32")
|
||||
yield (int64, "int64")
|
||||
yield (bool, "bool")
|
||||
|
||||
|
||||
def torch_dtype_to_string(dtype):
|
||||
for (torch_dtype, ait_dtype) in types_mapping():
|
||||
if dtype == torch_dtype:
|
||||
return ait_dtype
|
||||
raise ValueError(
|
||||
f"Got unsupported input dtype {dtype}! "
|
||||
f"Supported dtypes are: {list(types_mapping())}"
|
||||
)
|
||||
|
||||
|
||||
def string_to_torch_dtype(string_dtype):
|
||||
if string_dtype is None:
|
||||
# Many torch functions take optional dtypes, so
|
||||
# handling None is useful here.
|
||||
return None
|
||||
|
||||
for (torch_dtype, ait_dtype) in types_mapping():
|
||||
if string_dtype == ait_dtype:
|
||||
return torch_dtype
|
||||
raise ValueError(
|
||||
f"Got unsupported ait dtype {string_dtype}! "
|
||||
f"Supported dtypes are: {list(types_mapping())}"
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user