include aitemplate Model

hlky 2023-05-15 19:15:50 +01:00
parent b32c2eaafd
commit eccba18c17
5 changed files with 1379 additions and 1 deletion

comfy/aitemplate/dtype.py (new file, 158 lines)

@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
dtype definitions and utility functions of AITemplate
"""
_DTYPE2BYTE = {
    "bool": 1,
    "float16": 2,
    "float32": 4,
    "float": 4,
    "int": 4,
    "int32": 4,
    "int64": 8,
    "bfloat16": 2,
}
# Maps dtype strings to AITemplateDtype enum in model_interface.h.
# Must be kept in sync!
# We can consider defining an AITemplateDtype enum to use on the Python
# side at some point, but stick to strings for now to keep things consistent
# with other Python APIs.
_DTYPE_TO_ENUM = {
    "float16": 1,
    "float32": 2,
    "float": 2,
    "int": 3,
    "int32": 3,
    "int64": 4,
    "bool": 5,
    "bfloat16": 6,
}
def get_dtype_size(dtype: str) -> int:
    """Returns size (in bytes) of the given dtype str.

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    int
        Size (in bytes) of this dtype.
    """
    if dtype not in _DTYPE2BYTE:
        raise KeyError(f"Unknown dtype: {dtype}. Expected one of {_DTYPE2BYTE.keys()}")
    return _DTYPE2BYTE[dtype]


def normalize_dtype(dtype: str) -> str:
    """Returns a normalized dtype str.

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    str
        normalized dtype str.
    """
    if dtype == "int":
        return "int32"
    if dtype == "float":
        return "float32"
    return dtype
def dtype_str_to_enum(dtype: str) -> int:
    """Returns the AITemplateDtype enum value (defined in model_interface.h) of
    the given dtype str.

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    int
        the AITemplateDtype enum value.
    """
    if dtype not in _DTYPE_TO_ENUM:
        raise ValueError(
            f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPE_TO_ENUM.keys())}"
        )
    return _DTYPE_TO_ENUM[dtype]


def dtype_to_enumerator(dtype: str) -> str:
    """Returns the string representation of the AITemplateDtype enum
    (defined in model_interface.h) for the given dtype str.

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    str
        the AITemplateDtype enum string representation.
    """

    def _impl(dtype):
        if dtype == "float16":
            return "kHalf"
        elif dtype == "float32" or dtype == "float":
            return "kFloat"
        elif dtype == "int32" or dtype == "int":
            return "kInt"
        elif dtype == "int64":
            return "kLong"
        elif dtype == "bool":
            return "kBool"
        elif dtype == "bfloat16":
            return "kBFloat16"
        else:
            raise AssertionError(f"unknown dtype {dtype}")

    return f"AITemplateDtype::{_impl(dtype)}"
def is_same_dtype(dtype1: str, dtype2: str) -> bool:
    """Returns True if dtype1 and dtype2 are the same dtype and False otherwise.

    Parameters
    ----------
    dtype1: str
        A data type string.
    dtype2: str
        A data type string.

    Returns
    ----------
    bool
        whether dtype1 and dtype2 are the same dtype
    """
    return normalize_dtype(dtype1) == normalize_dtype(dtype2)
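For orientation, a minimal sketch of how these helpers behave (illustrative only, not part of the commit; it assumes the comfy.aitemplate package is importable):

# Illustrative usage of the dtype helpers above (not part of the commit).
from comfy.aitemplate.dtype import (
    dtype_str_to_enum,
    dtype_to_enumerator,
    get_dtype_size,
    is_same_dtype,
    normalize_dtype,
)

assert get_dtype_size("float16") == 2               # bytes per element
assert normalize_dtype("float") == "float32"        # aliases collapse to canonical names
assert dtype_str_to_enum("bfloat16") == 6           # matches AITemplateDtype in model_interface.h
assert dtype_to_enumerator("int64") == "AITemplateDtype::kLong"
assert is_same_dtype("int", "int32")                # compared after normalization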

comfy/aitemplate/misc.py (new file, 97 lines)

@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
miscellaneous utilities
"""
import hashlib
import logging
import os
import platform
def is_debug():
    logger = logging.getLogger("aitemplate")
    return logger.level == logging.DEBUG


def is_linux() -> bool:
    return platform.system() == "Linux"


def is_windows() -> bool:
    return os.name == "nt"


def setup_logger(name):
    root_logger = logging.getLogger(name)
    info_handle = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s %(levelname)s <%(name)s> %(message)s")
    info_handle.setFormatter(formatter)
    root_logger.addHandler(info_handle)
    root_logger.propagate = False
    DEFAULT_LOGLEVEL = logging.getLogger().level
    log_level_str = os.environ.get("LOGLEVEL", None)
    LOG_LEVEL = (
        getattr(logging, log_level_str.upper())
        if log_level_str is not None
        else DEFAULT_LOGLEVEL
    )
    root_logger.setLevel(LOG_LEVEL)
    return root_logger
def short_str(s, length=8) -> str:
    """
    Returns a hashed string, somewhat similar to URL shortener.
    """
    hash_str = hashlib.sha256(s.encode()).hexdigest()
    return hash_str[0:length]


def callstack_stats(enable=False):
    if enable:

        def decorator(f):
            import cProfile
            import io
            import pstats

            logger = logging.getLogger(__name__)

            def inner_function(*args, **kwargs):
                pr = cProfile.Profile()
                pr.enable()
                result = f(*args, **kwargs)
                pr.disable()
                s = io.StringIO()
                pstats.Stats(pr, stream=s).sort_stats(
                    pstats.SortKey.CUMULATIVE
                ).print_stats(30)
                logger.debug(s.getvalue())
                return result

            return inner_function

        return decorator
    else:

        def decorator(f):
            def inner_function(*args, **kwargs):
                return f(*args, **kwargs)

            return inner_function

        return decorator
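A short illustrative sketch of the logging and profiling helpers above (not part of the commit; it assumes comfy.aitemplate is importable and that LOGLEVEL is not already set in the environment):

# Illustrative usage only (not part of the commit).
import os

from comfy.aitemplate.misc import callstack_stats, is_debug, setup_logger

os.environ.setdefault("LOGLEVEL", "DEBUG")   # read by setup_logger
logger = setup_logger("aitemplate")          # StreamHandler + level taken from LOGLEVEL
print(is_debug())                            # True once the "aitemplate" logger is at DEBUG


@callstack_stats(enable=True)                # profiles the call and logs cProfile stats at DEBUG
def expensive_step():
    return sum(i * i for i in range(100_000))


expensive_step()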

comfy/aitemplate/model.py (new file, 1065 lines)

File diff suppressed because it is too large.


@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Functions for working with torch Tensors.
AITemplate doesn't depend on PyTorch, but it exposes
many APIs that work with torch Tensors anyways.
The functions in this file may assume that
`import torch` will work.
"""
def types_mapping():
    from torch import bfloat16, bool, float16, float32, int32, int64

    yield (float16, "float16")
    yield (bfloat16, "bfloat16")
    yield (float32, "float32")
    yield (int32, "int32")
    yield (int64, "int64")
    yield (bool, "bool")


def torch_dtype_to_string(dtype):
    for (torch_dtype, ait_dtype) in types_mapping():
        if dtype == torch_dtype:
            return ait_dtype
    raise ValueError(
        f"Got unsupported input dtype {dtype}! "
        f"Supported dtypes are: {list(types_mapping())}"
    )


def string_to_torch_dtype(string_dtype):
    if string_dtype is None:
        # Many torch functions take optional dtypes, so
        # handling None is useful here.
        return None

    for (torch_dtype, ait_dtype) in types_mapping():
        if string_dtype == ait_dtype:
            return torch_dtype
    raise ValueError(
        f"Got unsupported ait dtype {string_dtype}! "
        f"Supported dtypes are: {list(types_mapping())}"
    )
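A quick sanity check of the two conversion helpers above (illustrative only, not part of the commit):

# Illustrative round trip between torch dtypes and AITemplate dtype strings.
import torch

assert torch_dtype_to_string(torch.float16) == "float16"
assert string_to_torch_dtype("int64") == torch.int64
assert string_to_torch_dtype(None) is None  # optional dtypes pass straight through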


@@ -7,7 +7,7 @@ import hashlib
import traceback
import math
import time
-from aitemplate.compiler import Model
+from comfy.aitemplate.model import Model
from diffusers import LMSDiscreteScheduler
from PIL import Image
from PIL.PngImagePlugin import PngInfo
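For context, a hypothetical sketch of what the swapped import enables. It assumes the vendored comfy.aitemplate.model.Model keeps the upstream aitemplate.compiler.Model runtime interface (a constructor taking the path to a compiled module and a run_with_tensors method); the .so path and the tensor names below are placeholders, since model.py itself is suppressed in this diff:

# Hypothetical usage; the module path and "input0"/"output0" names are made up,
# and the Model interface is assumed to match upstream AITemplate.
import torch

from comfy.aitemplate.model import Model

module = Model("/path/to/compiled_unet.so")  # load a precompiled AITemplate module
inputs = {"input0": torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")}
outputs = {"output0": torch.empty(1, 4, 64, 64, dtype=torch.float16, device="cuda")}
module.run_with_tensors(inputs, outputs)     # fills the preallocated output tensors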