include aitemplate Model

This commit is contained in:
hlky 2023-05-15 19:15:50 +01:00
parent b32c2eaafd
commit eccba18c17
5 changed files with 1379 additions and 1 deletions

158
comfy/aitemplate/dtype.py Normal file
View File

@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
dtype definitions and utility functions of AITemplate
"""
_DTYPE2BYTE = {
"bool": 1,
"float16": 2,
"float32": 4,
"float": 4,
"int": 4,
"int32": 4,
"int64": 8,
"bfloat16": 2,
}
# Maps dtype strings to AITemplateDtype enum in model_interface.h.
# Must be kept in sync!
# We can consider defining an AITemplateDtype enum to use on the Python
# side at some point, but stick to strings for now to keep things consistent
# with other Python APIs.
_DTYPE_TO_ENUM = {
"float16": 1,
"float32": 2,
"float": 2,
"int": 3,
"int32": 3,
"int64": 4,
"bool": 5,
"bfloat16": 6,
}
def get_dtype_size(dtype: str) -> int:
"""Returns size (in bytes) of the given dtype str.
Parameters
----------
dtype: str
A data type string.
Returns
----------
int
Size (in bytes) of this dtype.
"""
if dtype not in _DTYPE2BYTE:
raise KeyError(f"Unknown dtype: {dtype}. Expected one of {_DTYPE2BYTE.keys()}")
return _DTYPE2BYTE[dtype]
def normalize_dtype(dtype: str) -> str:
"""Returns a normalized dtype str.
Parameters
----------
dtype: str
A data type string.
Returns
----------
str
normalized dtype str.
"""
if dtype == "int":
return "int32"
if dtype == "float":
return "float32"
return dtype
def dtype_str_to_enum(dtype: str) -> int:
"""Returns the AITemplateDtype enum value (defined in model_interface.h) of
the given dtype str.
Parameters
----------
dtype: str
A data type string.
Returns
----------
int
the AITemplateDtype enum value.
"""
if dtype not in _DTYPE_TO_ENUM:
raise ValueError(
f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPE_TO_ENUM.keys())}"
)
return _DTYPE_TO_ENUM[dtype]
def dtype_to_enumerator(dtype: str) -> str:
"""Returns the string representation of the AITemplateDtype enum
(defined in model_interface.h) for the given dtype str.
Parameters
----------
dtype: str
A data type string.
Returns
----------
str
the AITemplateDtype enum string representation.
"""
def _impl(dtype):
if dtype == "float16":
return "kHalf"
elif dtype == "float32" or dtype == "float":
return "kFloat"
elif dtype == "int32" or dtype == "int":
return "kInt"
elif dtype == "int64":
return "kLong"
elif dtype == "bool":
return "kBool"
elif dtype == "bfloat16":
return "kBFloat16"
else:
raise AssertionError(f"unknown dtype {dtype}")
return f"AITemplateDtype::{_impl(dtype)}"
def is_same_dtype(dtype1: str, dtype2: str) -> bool:
"""Returns True if dtype1 and dtype2 are the same dtype and False otherwise.
Parameters
----------
dtype1: str
A data type string.
dtype2: str
A data type string.
Returns
----------
bool
whether dtype1 and dtype2 are the same dtype
"""
return normalize_dtype(dtype1) == normalize_dtype(dtype2)

97
comfy/aitemplate/misc.py Normal file
View File

@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
miscellaneous utilities
"""
import hashlib
import logging
import os
import platform
def is_debug():
logger = logging.getLogger("aitemplate")
return logger.level == logging.DEBUG
def is_linux() -> bool:
return platform.system() == "Linux"
def is_windows() -> bool:
return os.name == "nt"
def setup_logger(name):
root_logger = logging.getLogger(name)
info_handle = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s %(levelname)s <%(name)s> %(message)s")
info_handle.setFormatter(formatter)
root_logger.addHandler(info_handle)
root_logger.propagate = False
DEFAULT_LOGLEVEL = logging.getLogger().level
log_level_str = os.environ.get("LOGLEVEL", None)
LOG_LEVEL = (
getattr(logging, log_level_str.upper())
if log_level_str is not None
else DEFAULT_LOGLEVEL
)
root_logger.setLevel(LOG_LEVEL)
return root_logger
def short_str(s, length=8) -> str:
"""
Returns a hashed string, somewhat similar to URL shortener.
"""
hash_str = hashlib.sha256(s.encode()).hexdigest()
return hash_str[0:length]
def callstack_stats(enable=False):
if enable:
def decorator(f):
import cProfile
import io
import pstats
logger = logging.getLogger(__name__)
def inner_function(*args, **kwargs):
pr = cProfile.Profile()
pr.enable()
result = f(*args, **kwargs)
pr.disable()
s = io.StringIO()
pstats.Stats(pr, stream=s).sort_stats(
pstats.SortKey.CUMULATIVE
).print_stats(30)
logger.debug(s.getvalue())
return result
return inner_function
return decorator
else:
def decorator(f):
def inner_function(*args, **kwargs):
return f(*args, **kwargs)
return inner_function
return decorator

1065
comfy/aitemplate/model.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Functions for working with torch Tensors.
AITemplate doesn't depend on PyTorch, but it exposes
many APIs that work with torch Tensors anyways.
The functions in this file may assume that
`import torch` will work.
"""
def types_mapping():
from torch import bfloat16, bool, float16, float32, int32, int64
yield (float16, "float16")
yield (bfloat16, "bfloat16")
yield (float32, "float32")
yield (int32, "int32")
yield (int64, "int64")
yield (bool, "bool")
def torch_dtype_to_string(dtype):
for (torch_dtype, ait_dtype) in types_mapping():
if dtype == torch_dtype:
return ait_dtype
raise ValueError(
f"Got unsupported input dtype {dtype}! "
f"Supported dtypes are: {list(types_mapping())}"
)
def string_to_torch_dtype(string_dtype):
if string_dtype is None:
# Many torch functions take optional dtypes, so
# handling None is useful here.
return None
for (torch_dtype, ait_dtype) in types_mapping():
if string_dtype == ait_dtype:
return torch_dtype
raise ValueError(
f"Got unsupported ait dtype {string_dtype}! "
f"Supported dtypes are: {list(types_mapping())}"
)

View File

@ -7,7 +7,7 @@ import hashlib
import traceback import traceback
import math import math
import time import time
from aitemplate.compiler import Model from comfy.aitemplate.model import Model
from diffusers import LMSDiscreteScheduler from diffusers import LMSDiscreteScheduler
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo