Improvements for Wan 2.2 support
- Add xet support and add the xet cache to the manageable directories; xet is enabled by default.
- Fix logging to the root logger in various places.
- Improve logging around model loading and unloading.
- The TorchCompileModel node now supports the VAE.
- A missing torchaudio now causes less noise in the logs.
- Feature flags are assumed to support everything in the distributed progress context.
- Fixes progress notifications.
This commit is contained in: parent b5a50301f6, commit 03e5430121
```diff
@@ -84,7 +84,7 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
     stream_handler = logging.StreamHandler()
     stream_handler.setFormatter(logging.Formatter(
-        "%(asctime)s [%(name)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
+        "%(asctime)s [%(levelname)s] [%(name)s] [%(filename)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    ))
```
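The format change above goes hand in hand with the commit's move away from root-logger calls. A minimal sketch of the named-logger pattern (plain stdlib logging; the function name is only an example):

```python
import logging

# Module-level logger: records carry the emitting module in %(name)s
# (e.g. "comfy.model_patcher") instead of "root".
logger = logging.getLogger(__name__)


def do_work():
    logger.debug("attributed to %s rather than the root logger", __name__)
```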
```diff
@@ -86,6 +86,9 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
     if "HF_HUB_CACHE" in os.environ:
         hf_cache_paths.additional_absolute_directory_paths.append(os.environ.get("HF_HUB_CACHE"))

+    hf_xet = ModelPaths(["xet"], supported_extensions=set())
+    if "HF_XET_CACHE" in os.environ:
+        hf_xet.additional_absolute_directory_paths.append(os.environ.get("HF_XET_CACHE"))
     model_paths_to_add = [
         ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
         ModelPaths(["configs"], additional_absolute_directory_paths=[get_package_as_path("comfy.configs")], supported_extensions={".yaml"}),
@@ -107,6 +110,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
         ModelPaths(["classifiers"], supported_extensions=set()),
         ModelPaths(["huggingface"], supported_extensions=set()),
         hf_cache_paths,
+        hf_xet,
     ]
     for model_paths in model_paths_to_add:
         if replace_existing:
```
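A hedged sketch of how the new manageable `xet` directory can be relocated: the hunk above only reads `HF_XET_CACHE` during path initialization, so the variable has to be set before ComfyUI starts (the directory below is just an example).

```python
import os

# Must be set before init_default_paths() runs; the path itself is an example.
os.environ["HF_XET_CACHE"] = "/mnt/fast-ssd/xet-cache"
```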
```diff
@@ -18,7 +18,7 @@ from .. import options
 from ..app import logger

 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
 os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
 os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
 os.environ["BITSANDBYTES_NOWELCOME"] = "1"
```
```diff
@@ -221,7 +221,7 @@ class PromptServer(ExecutorToClientProgress):
                              handler_args={'max_field_size': 16380},
                              middlewares=middlewares)
         self.sockets = dict()
-        self.sockets_metadata = dict()
+        self._sockets_metadata = dict()
         self.web_root = (
             FrontendManager.init_frontend(args.front_end_version)
             if args.front_end_root is None
@@ -278,16 +278,16 @@ class PromptServer(ExecutorToClientProgress):
                             sid,
                         )

-                        logging.info(
+                        logger.info(
                             f"Feature flags negotiated for client {sid}: {client_flags}"
                         )
                         first_message = False
                 except json.JSONDecodeError:
-                    logging.warning(
+                    logger.warning(
                         f"Invalid JSON received from client {sid}: {msg.data}"
                     )
                 except Exception as e:
-                    logging.error(f"Error processing WebSocket message: {e}")
+                    logger.error(f"Error processing WebSocket message: {e}")
         finally:
             self.sockets.pop(sid, None)
             self.sockets_metadata.pop(sid, None)
@@ -1236,3 +1236,11 @@ class PromptServer(ExecutorToClientProgress):
         message = encode_text_for_progress(node_id, text)

         self.send_sync(BinaryEventTypes.TEXT, message, sid)
+
+    @property
+    def sockets_metadata(self):
+        return self._sockets_metadata
+
+    @sockets_metadata.setter
+    def sockets_metadata(self, value):
+        self._sockets_metadata = value
```
```diff
@@ -1,10 +1,11 @@
 from __future__ import annotations  # for Python 3.7-3.9

-import PIL.Image
 import concurrent.futures
 import typing
 from enum import Enum
 from typing import Optional, Literal, Protocol, Union, NamedTuple, List

+import PIL.Image
 from typing_extensions import NotRequired, TypedDict

 from .encode_text_for_progress import encode_text_for_progress
@@ -79,11 +80,39 @@ class DependencyExecutionErrorMessage(TypedDict):
     current_inputs: list[typing.Never]


+class ActiveNodeProgressState(TypedDict, total=True):
+    value: float
+    max: float
+    # a string value from the NodeState enum
+    state: Literal["pending", "running", "finished", "error"]
+    node_id: str
+    prompt_id: str
+    display_node_id: str
+    parent_node_id: str
+    real_node_id: str
+
+
+class ProgressStateMessage(TypedDict, total=True):
+    prompt_id: str
+    nodes: dict[str, ActiveNodeProgressState]
+
+
 ExecutedMessage = ExecutingMessage

-SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]
+SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed", "progress_state"], BinaryEventTypes, None]

-SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[UnencodedPreviewImageMessage, PreviewImageMetadata], bytes, bytearray, str, None]
+SendSyncData = Union[ProgressStateMessage, StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[PIL.Image.Image, PreviewImageMetadata], bytes, bytearray, str, None]
+
+
+class SocketsMetadata(TypedDict, total=True):
+    feature_flags: dict[str, typing.Any]
+
+
+class DefaultSocketsMetadata(TypedDict, total=True):
+    __unimplemented: Literal[True]
+
+
+SocketsMetadataType = dict[str, SocketsMetadata] | DefaultSocketsMetadata
+
+
 class ExecutorToClientProgress(Protocol):
@@ -108,6 +137,10 @@ class ExecutorToClientProgress(Protocol):
         """
         return False

+    @property
+    def sockets_metadata(self) -> SocketsMetadataType:
+        return {"__unimplemented": True}
+
     def send_sync(self,
                   event: SendSyncEvent,
                   data: SendSyncData,
```
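A minimal sketch of how the new `progress_state` payload is shaped, assuming only the TypedDicts added above (redeclared here so the snippet is self-contained; nothing beyond the diff is implied about the server API):

```python
from typing import Literal, TypedDict


class ActiveNodeProgressState(TypedDict, total=True):
    value: float
    max: float
    state: Literal["pending", "running", "finished", "error"]
    node_id: str
    prompt_id: str
    display_node_id: str
    parent_node_id: str
    real_node_id: str


class ProgressStateMessage(TypedDict, total=True):
    prompt_id: str
    nodes: dict[str, ActiveNodeProgressState]


def summarize(msg: ProgressStateMessage) -> str:
    # Collapse a progress_state message into a one-line status string.
    parts = [f"{nid}: {s['state']} {s['value']:.0f}/{s['max']:.0f}" for nid, s in msg["nodes"].items()]
    return f"prompt {msg['prompt_id']}: " + "; ".join(parts)
```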
```diff
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
 try:
     import torchaudio  # pylint: disable=import-error
 except:
-    logger.warning("torchaudio missing, ACE model will be broken")
+    logger.debug("torchaudio missing, ACE model will be broken")

 import torchvision.transforms as transforms
 from .music_vocoder import ADaMoSHiFiGANV1
```
```diff
@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
 try:
     from torchaudio.transforms import MelScale  # pylint: disable=import-error
 except:
-    logger.warning("torchaudio missing, ACE model will be broken")
+    logger.debug("torchaudio missing, ACE model will be broken")

 from .... import model_management

```
```diff
@@ -10,6 +10,8 @@ from .... import ops

 ops = ops.disable_weight_init

+logger = logging.getLogger(__name__)
+
 if model_management.xformers_enabled_vae():
     import xformers  # pylint: disable=import-error
     import xformers.ops  # pylint: disable=import-error
@@ -242,7 +244,7 @@ def slice_attention(q, k, v):
             steps *= 2
             if steps > 128:
                 raise e
-            logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
+            logger.warning("out of memory error, increasing steps and trying again {}".format(steps))

     return r1

@@ -296,20 +298,20 @@ def pytorch_attention(q, k, v):
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
-        logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+        logger.warning("scaled_dot_product_attention OOMed: switched to slice attention")
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
     return out


 def vae_attention():
     if model_management.xformers_enabled_vae():
-        logging.debug("Using xformers attention in VAE")
+        logger.debug("Using xformers attention in VAE")
         return xformers_attention
     elif model_management.pytorch_attention_enabled_vae():
-        logging.debug("Using pytorch attention in VAE")
+        logger.debug("Using pytorch attention in VAE")
         return pytorch_attention
     else:
-        logging.debug("Using split attention in VAE")
+        logger.debug("Using split attention in VAE")
         return normal_attention


@@ -650,7 +652,7 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        logging.debug("Working with z of shape {} = {} dimensions.".format(
+        logger.debug("Working with z of shape {} = {} dimensions.".format(
             self.z_shape, np.prod(self.z_shape)))

         # z to block_in
```
```diff
@@ -16,13 +16,13 @@
 along with this program. If not, see <https://www.gnu.org/licenses/>.
 """

-import math
-
 import logging
-import torch
+import math
 from enum import Enum
 from typing import TypeVar, Type, Protocol, Any, Optional

+import torch
+
 from . import conds
 from . import latent_formats
 from . import model_management
@@ -50,9 +50,9 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
 from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
+from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
 from .ldm.pixart.pixartms import PixArtMS
 from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel
-from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
 from .model_management_types import ModelManageable
 from .model_sampling import CONST, ModelSamplingDiscreteFlow, ModelSamplingFlux, IMG_TO_IMG
 from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCosmosRFlow, V_PREDICTION, \
@@ -60,6 +60,8 @@ from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCo
 from .ops import Operations
 from .patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers

+logger = logging.getLogger(__name__)
+

 class ModelType(Enum):
     EPS = 1
@@ -149,8 +151,8 @@ class BaseModel(torch.nn.Module):
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
             if model_management.force_channels_last():
                 self.diffusion_model.to(memory_format=torch.channels_last)
-                logging.debug("using channels last mode for diffusion model")
-            logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
+                logger.debug("using channels last mode for diffusion model")
+            logger.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
         else:
             self.operations = None
         self.model_type = model_type
@@ -161,8 +163,8 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = 0

         self.concat_keys = ()
-        logging.debug("model_type {}".format(model_type.name))
-        logging.debug("adm {}".format(self.adm_channels))
+        logger.debug("model_type {}".format(model_type.name))
+        logger.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
         self.memory_usage_factor_conds = ()
         self.training = False
@@ -310,10 +312,10 @@ class BaseModel(torch.nn.Module):
         to_load = self.model_config.process_unet_state_dict(to_load)
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         if len(m) > 0:
-            logging.warning("unet missing: {}".format(m))
+            logger.warning("unet missing: {}".format(m))

         if len(u) > 0:
-            logging.warning("unet unexpected: {}".format(u))
+            logger.warning("unet unexpected: {}".format(u))
         del to_load
         return self

@@ -1227,6 +1229,7 @@ class WAN21_Camera(WAN21):
         out['camera_conditions'] = conds.CONDRegular(camera_conditions)
         return out

+
 class WAN22(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=WanModel)
@@ -1252,6 +1255,7 @@ class WAN22(BaseModel):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image

+
 class Hunyuan3Dv2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model)
@@ -1321,6 +1325,7 @@ class ACEStep(BaseModel):
         out['lyrics_strength'] = conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
         return out

+
 class Omnigen2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=OmniGen2Transformer2DModel)
```
```diff
@@ -14,7 +14,7 @@ from pathlib import Path
 from typing import List, Optional, Final, Set

 # enable better transfer
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"

 import tqdm
 from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
@@ -486,6 +486,7 @@ KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("comfyanonymous/cosmos_1.0_text_encoder_and_VAE_ComfyUI", "vae/cosmos_cv8x8x8_1.0.safetensors"),
     HuggingFile("Comfy-Org/Lumina_Image_2.0_Repackaged", "split_files/vae/ae.safetensors", save_with_filename="lumina_image_2.0-ae.safetensors"),
     HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/vae/wan_2.1_vae.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/vae/wan2.2_vae.safetensors"),
 ], folder_name="vae")

 KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
@@ -546,6 +547,15 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("Comfy-Org/Cosmos_Predict2_repackaged", "cosmos_predict2_2B_video2world_480p_16fps.safetensors"),
     HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/diffusion_models/wan2.1_vace_14B_fp16.safetensors"),
     HuggingFile("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/diffusion_models/wan2.1_fun_camera_v1.1_1.3B_bf16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp16.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors"),
+    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
     HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
 ], folder_names=["diffusion_models", "unet"])

```
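The new Wan 2.2 entries resolve through huggingface_hub like any other `HuggingFile`; a hedged sketch of fetching one of them directly (the repo and file names are the ones listed above, the cache location is whatever your HF/xet configuration selects):

```python
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
    filename="split_files/vae/wan2.2_vae.safetensors",
)
print(path)  # resolved path inside the Hugging Face cache (xet-backed when hf_xet is installed)
```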
```diff
@@ -742,7 +742,6 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0

         loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
-        logger.debug(f"Loaded {loaded_model}")

     span = get_current_span()
     span.set_attribute("models_to_load", list(map(str, models_to_load)))
```
```diff
@@ -29,6 +29,7 @@ from typing import Callable, Optional
 import torch
 import torch.nn
 from humanize import naturalsize
+from natsort import natsorted

 from . import model_management, lora
 from . import patcher_extension
@@ -119,6 +120,7 @@ def wipe_lowvram_weight(m):
     if hasattr(m, "bias_function"):
         m.bias_function = []

+
 def move_weight_functions(m, device):
     if device is None:
         return 0
@@ -289,7 +291,7 @@ class ModelPatcher(ModelManageable):
         return self._force_cast_weights

     @force_cast_weights.setter
-    def force_cast_weights(self, value:bool) -> None:
+    def force_cast_weights(self, value: bool) -> None:
         self._force_cast_weights = value

     def lowvram_patch_counter(self):
@@ -475,7 +477,7 @@ class ModelPatcher(ModelManageable):
         self.add_object_patch("manual_cast_dtype", dtype)
         if dtype is not None:
             self.force_cast_weights = True
-        self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
+        self.patches_uuid = uuid.uuid4()  # TODO: optimize by preventing a full model reload for this

     def add_weight_wrapper(self, name, function):
         self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
@@ -630,7 +632,6 @@ class ModelPatcher(ModelManageable):
             else:
                 set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

-
     def _load_list(self) -> list[LoadingListItem]:
         loading = []
         for n, m in self.model.named_modules():
@@ -715,6 +716,7 @@ class ModelPatcher(ModelManageable):
                 mem_counter += move_weight_functions(m, device_to)

         load_completely.sort(reverse=True)
+        models_loaded_regularly: list[str] = []
         for x in load_completely:
             n = x.name
             m = x.module
@@ -726,17 +728,17 @@ class ModelPatcher(ModelManageable):
                 for param in params:
                     self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)

-                logger.debug("lowvram: loaded module regularly {} {}".format(n, m))
+                models_loaded_regularly.append("name={} module={}".format(n, m))
                 m.comfy_patched_weights = True
+        logger.debug("lowvram: loaded module regularly: {}".format(", ".join(models_loaded_regularly)))
         for x in load_completely:
             x.module.to(device_to)

         if lowvram_counter > 0:
-            logger.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logger.debug("loaded partially lowvram_model_memory={}MB mem_counter={}MB patch_counter={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self._memory_measurements.model_lowvram = True
         else:
-            logger.debug("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logger.debug("loaded completely lowvram_model_memory={}MB mem_counter={}MB full_load={}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self._memory_measurements.model_lowvram = False
         if full_load:
             self.model.to(device_to)
@@ -812,6 +814,7 @@ class ModelPatcher(ModelManageable):
         self.object_patches_backup.clear()

     def partially_unload(self, device_to, memory_to_free=0):
+        freed_layers: list[str] = []
         with self.use_ejected():
             hooks_unpatched = False
             memory_freed = 0
@@ -867,7 +870,9 @@ class ModelPatcher(ModelManageable):
                     m.comfy_cast_weights = True
                     m.comfy_patched_weights = False
                     memory_freed += module_mem
-                    logging.debug("freed {}".format(n))
+                    freed_layers.append(n)

+            logger.debug("freed {}".format(natsorted(freed_layers)))
+
             self._memory_measurements.model_lowvram = True
             self._memory_measurements.lowvram_patch_counter += patch_counter
@@ -1190,7 +1195,7 @@ class ModelPatcher(ModelManageable):
         model_sd_keys_set = set(model_sd_keys)
         for key in cached_weights:
             if key not in model_sd_keys:
-                logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
+                logger.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                 continue
             self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
             model_sd_keys_set.remove(key)
@@ -1203,7 +1208,7 @@ class ModelPatcher(ModelManageable):
         original_weights = self.get_key_patches()
         for key in relevant_patches:
             if key not in model_sd_keys:
-                logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
+                logger.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
                 continue
             self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                              memory_counter=memory_counter)
@@ -1265,7 +1270,7 @@ class ModelPatcher(ModelManageable):
         del out_weight
         del weight

-    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
+    def unpatch_hooks(self, whitelist_keys_set: set[str] = None) -> None:
         with self.use_ejected():
             if len(self.hook_backup) == 0:
                 self.current_hooks = None
```
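The pattern behind the `freed_layers` and `models_loaded_regularly` changes, shown in isolation (stdlib plus natsort, which the hunk imports; the function name below is made up for the sketch):

```python
import logging

from natsort import natsorted

logger = logging.getLogger(__name__)


def unload_layers(layer_names: list[str]) -> None:
    freed_layers: list[str] = []
    for name in layer_names:
        # ... move the corresponding module off the GPU here ...
        freed_layers.append(name)
    # One naturally sorted debug line instead of one line per layer.
    logger.debug("freed {}".format(natsorted(freed_layers)))
```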
```diff
@@ -7,9 +7,10 @@ from PIL import Image
 from tqdm import tqdm
 from typing_extensions import override

+from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.module_property import create_module_properties
 from .execution_context import current_execution_context
-from .progress_types import AbstractProgressRegistry
+from .progress_types import AbstractProgressRegistry, PreviewImageMetadata

 if TYPE_CHECKING:
     from comfy_execution.graph import DynamicPrompt
@@ -157,7 +158,7 @@ class WebUIProgressHandler(ProgressHandler):
     Handler that sends progress updates to the WebUI via WebSockets.
     """

-    def __init__(self, server_instance):
+    def __init__(self, server_instance: ExecutorToClientProgress):
         super().__init__("webui")
         self.server_instance = server_instance

@@ -216,7 +217,7 @@ class WebUIProgressHandler(ProgressHandler):
             self.server_instance.client_id,
             "supports_preview_metadata",
         ):
-            metadata = {
+            metadata: PreviewImageMetadata = {
                 "node_id": node_id,
                 "prompt_id": prompt_id,
                 "display_node_id": self.registry.dynprompt.get_display_node_id(
@@ -327,7 +328,7 @@ class ProgressRegistry(AbstractProgressRegistry):

 # Global registry instance
 @_module_properties.getter
-def _global_progress_registry() -> ProgressRegistry:
+def _global_progress_registry() -> AbstractProgressRegistry | None:
     return current_execution_context().progress_registry

```
comfy/sd.py (41 changed lines)
```diff
@@ -41,23 +41,23 @@ from .model_management import load_models_gpu
 from .model_patcher import ModelPatcher
 from .t2i_adapter import adapter
 from .taesd import taesd
+from .text_encoders import ace
 from .text_encoders import aura_t5
-from .text_encoders import hidream
 from .text_encoders import cosmos
 from .text_encoders import flux
 from .text_encoders import genmo
+from .text_encoders import hidream
 from .text_encoders import hunyuan_video
 from .text_encoders import hydit
 from .text_encoders import long_clipl
 from .text_encoders import lt
 from .text_encoders import lumina2
+from .text_encoders import omnigen2
 from .text_encoders import pixart_t5
 from .text_encoders import sa_t5
 from .text_encoders import sd2_clip
 from .text_encoders import sd3_clip
 from .text_encoders import wan
-from .text_encoders import ace
-from .text_encoders import omnigen2
 from .utils import ProgressBar

 logger = logging.getLogger(__name__)
@@ -280,7 +280,9 @@ class CLIP:


 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
+    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None, no_init=False):
+        if no_init:
+            return
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys():  # diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)

@@ -469,7 +471,7 @@ class VAE:
             ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
             self.first_stage_model = ShapeVAE(**ddconfig)
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
-        elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
+        elif "vocoder.backbone.channel_layers.0.0.bias" in sd:  # Ace Step Audio
             self.first_stage_model = MusicDCAE(source_sample_rate=44100)
             self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
             self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
@@ -511,6 +513,29 @@ class VAE:
         self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
         logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

+    def clone(self):
+        n = VAE(no_init=True)
+        n.memory_used_encode = self.memory_used_encode
+        n.memory_used_decode = self.memory_used_decode
+        n.downscale_ratio = self.downscale_ratio
+        n.upscale_ratio = self.upscale_ratio
+        n.latent_channels = self.latent_channels
+        n.latent_dim = self.latent_dim
+        n.output_channels = self.output_channels
+        n.process_input = self.process_input
+        n.process_output = self.process_output
+        n.working_dtypes = self.working_dtypes.copy()
+        n.disable_offload = self.disable_offload
+        n.downscale_index_formula = self.downscale_index_formula
+        n.upscale_index_formula = self.upscale_index_formula
+        n.extra_1d_channel = self.extra_1d_channel
+        n.first_stage_model = self.first_stage_model
+        n.device = self.device
+        n.vae_dtype = self.vae_dtype
+        n.output_device = self.output_device
+        n.patcher = self.patcher.clone()
+        return n
+
     def throw_exception_if_invalid(self):
         if self.first_stage_model is None:
             raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
@@ -920,7 +945,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif clip_type == CLIPType.HIDREAM:
             clip_target.clip = hidream.hidream_clip(**t5xxl_detect(clip_data),
                                                     clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
             clip_target.tokenizer = hidream.HiDreamTokenizer
         else:  # CLIPType.MOCHI
             clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data))
@@ -945,7 +970,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif te_model == TEModel.LLAMA3_8:
             clip_target.clip = hidream.hidream_clip(**llama_detect(clip_data),
                                                     clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
             clip_target.tokenizer = hidream.HiDreamTokenizer
         elif te_model == TEModel.QWEN25_3B:
             clip_target.clip = omnigen2.te(**llama_detect(clip_data))
@@ -1033,6 +1058,7 @@ def model_detection_error_hint(path, state_dict):
         return "\nHINT: This seems to be a Lora file and Lora files should be put in the lora folder and loaded with a lora loader node.."
     return ""

+
 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
     logger.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
     model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
@@ -1097,7 +1123,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
             return None
         return (diffusion_model, None, VAE(sd={}), None)  # The VAE object is there to throw an exception if it's actually used'
-
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if model_config.scaled_fp8 is not None:
         weight_dtype = None
```
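A hedged sketch of what `VAE.clone()` enables: node code can hand out a patched copy while leaving the original untouched, because only the `ModelPatcher` is cloned and the weights (`first_stage_model`) stay shared. The attribute names follow the hunk above; `vae` stands for any loaded VAE object.

```python
def compile_safe_copy(vae):
    # The clone shares first_stage_model but owns its own patcher, so object patches
    # (for example a torch.compile wrapper on the decoder) do not leak into other users of `vae`.
    vae_copy = vae.clone()
    assert vae_copy.first_stage_model is vae.first_stage_model
    assert vae_copy.patcher is not vae.patcher
    return vae_copy
```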
```diff
@@ -41,6 +41,7 @@ from einops import rearrange
 from torch.nn.functional import interpolate
 from tqdm import tqdm

+from comfy_api import feature_flags
 from . import interruption, checkpoint_pickle
 from .cli_args import args
 from .component_model import files
@@ -48,6 +49,7 @@ from .component_model.deprecation import _deprecate_method
 from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
 from .component_model.queue_types import BinaryEventTypes
 from .execution_context import current_execution_context
+from .progress import get_progress_state

 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap
@@ -1106,22 +1108,23 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
     return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)


-def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None, node_id: str = None):
-    server = server or current_execution_context().server
-    # todo: this should really be from the context. right now the server is behaving like a context
-    client_id = client_id or server.client_id
+def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None, node_id: str = None, prompt_id: str = None):
+    context = current_execution_context()
+    server = server or context.server
+    executing_context = context
+    prompt_id = prompt_id or executing_context.task_id or server.last_prompt_id
+    node_id = node_id or executing_context.node_id or server.last_node_id
     interruption.throw_exception_if_processing_interrupted()
-    progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": node_id or server.last_node_id}
+    progress: ProgressMessage = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
+    # todo: is this still necessary?
     if isinstance(preview_image_or_data, dict):
         progress["output"] = preview_image_or_data

+    # this is responsible for encoding the image
+    get_progress_state().update_progress(node_id, value, total, preview_image_or_data)
     server.send_sync("progress", progress, client_id)
-
-    # todo: investigate a better way to send the image data, since it needs the node ID
-    if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict):
-        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id)


 def set_progress_bar_enabled(enabled: bool):
     warnings.warn(
         "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
```
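For custom nodes the usual entry point is `ProgressBar` from `comfy.utils`, which funnels into `_progress_bar_update` above; a hedged sketch, assuming the upstream `update_absolute(value, total)` signature:

```python
from comfy.utils import ProgressBar


def run_steps(num_steps: int) -> None:
    pbar = ProgressBar(num_steps)
    for step in range(num_steps):
        # ... one unit of work ...
        # Emits a "progress" message; prompt_id/node_id now come from the execution context.
        pbar.update_absolute(step + 1, num_steps)
```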
```diff
@@ -43,7 +43,8 @@ def get_connection_feature(
 def supports_feature(
     sockets_metadata: Dict[str, Dict[str, Any]],
     sid: str,
-    feature_name: str
+    feature_name: str,
+    force=True,
 ) -> bool:
     """
     Check if a connection supports a specific feature.
@@ -52,10 +53,13 @@ def supports_feature(
         sockets_metadata: Dictionary of socket metadata
         sid: Session ID of the connection
         feature_name: Name of the feature to check
+        force (bool): If it cannot be determined, assume True

     Returns:
         Boolean indicating if feature is supported
     """
+    if sockets_metadata is None or "__unimplemented" in sockets_metadata:
+        return force
     return get_connection_feature(sockets_metadata, sid, feature_name, False) is True


```
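A simplified stand-in for the new fallback, to show the intended behavior: when the metadata is missing or is the `__unimplemented` default (as in the distributed progress context), every feature is assumed to be supported. The lookup below inlines what `get_connection_feature` does in spirit; the real helper may differ.

```python
from typing import Any, Dict


def supports_feature(sockets_metadata: Dict[str, Dict[str, Any]], sid: str, feature_name: str, force: bool = True) -> bool:
    if sockets_metadata is None or "__unimplemented" in sockets_metadata:
        return force
    return sockets_metadata.get(sid, {}).get("feature_flags", {}).get(feature_name, False) is True


assert supports_feature({"__unimplemented": True}, "sid-1", "supports_preview_metadata") is True
assert supports_feature({"sid-1": {"feature_flags": {}}}, "sid-1", "supports_preview_metadata") is False
```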
```diff
@@ -11,6 +11,7 @@ from comfy import model_management
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_patcher import ModelPatcher
 from comfy.nodes.package_typing import CustomNode, InputTypes
+from comfy.sd import VAE
 from comfy_api.torch_helpers import set_torch_compile_wrapper

 logger = logging.getLogger(__name__)
@@ -45,6 +46,7 @@ def write_atomic(


 torch._inductor.codecache.write_atomic = write_atomic
+# torch._inductor.utils.is_big_gpu = lambda *args: True


 class TorchCompileModel(CustomNode):
@@ -52,10 +54,10 @@ class TorchCompileModel(CustomNode):
     def INPUT_TYPES(s):
         return {
             "required": {
-                "model": ("MODEL",),
+                "model": ("MODEL,VAE",),
             },
             "optional": {
-                "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
+                "object_patch": ("STRING", {"default": ""}),
                 "fullgraph": ("BOOLEAN", {"default": False}),
                 "dynamic": ("BOOLEAN", {"default": False}),
                 "backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
@@ -64,15 +66,14 @@ class TorchCompileModel(CustomNode):
             }
         }

-    RETURN_TYPES = ("MODEL",)
+    RETURN_TYPES = ("MODEL,VAE",)
     FUNCTION = "patch"
+    RETURN_NAMES = ("model or vae",)

     CATEGORY = "_for_testing"
     EXPERIMENTAL = True

-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[Callable]:
-        if object_patch is None:
-            object_patch = DIFFUSION_MODEL
+    def patch(self, model: ModelPatcher | VAE | torch.nn.Module, object_patch: str | None = "", fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[Callable]:
         compile_kwargs = {
             "fullgraph": fullgraph,
             "dynamic": dynamic,
@@ -99,17 +100,26 @@ class TorchCompileModel(CustomNode):
             }
             move_to_gpu = True
             del compile_kwargs["mode"]
-        if isinstance(model, ModelPatcher) or isinstance(model, TransformersManagedModel):
-            m = model.clone()
+        if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
+            to_return = model.clone()
+            object_patches = [p.strip() for p in object_patch.split(",")]
+            patcher: ModelPatcher
+            if isinstance(to_return, VAE):
+                patcher = to_return.patcher
+                object_patches = ["encoder", "decoder"]
+            else:
+                patcher = to_return
+                if object_patch is None or len(object_patches) == 0:
+                    object_patches = [DIFFUSION_MODEL]
             if move_to_gpu:
                 model_management.unload_all_models()
-                model_management.load_models_gpu([m])
-            set_torch_compile_wrapper(m, object_patch=object_patch, **compile_kwargs)
-            m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
+                model_management.load_models_gpu([patcher])
+            set_torch_compile_wrapper(patcher, keys=object_patches, **compile_kwargs)
+            # m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
             # todo: do we want to move something back off the GPU?
             # if move_to_gpu:
             #     model_management.unload_all_models()
-            return m,
+            return to_return,
         elif isinstance(model, torch.nn.Module):
             if move_to_gpu:
                 model_management.unload_all_models()
@@ -119,7 +129,7 @@ class TorchCompileModel(CustomNode):
                 model.to(device=model_management.unet_offload_device())
                 return res,
             else:
-                logging.warning("Encountered a model that cannot be compiled")
+                logger.warning("Encountered a model that cannot be compiled")
                 return model,
         except OSError as os_error:
             try:
@@ -132,7 +142,7 @@ class TorchCompileModel(CustomNode):
                 torch._inductor.utils.clear_inductor_caches()  # pylint: disable=no-member
             except Exception:
                 pass
-            logging.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
+            logger.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
             return model,
@@ -160,7 +170,7 @@ class QuantizeModel(CustomNode):
     RETURN_TYPES = ("MODEL",)

     def warn_in_place(self, model: ModelPatcher):
-        logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")
+        logger.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")

     def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
         model = model.clone()
@@ -179,7 +189,7 @@ class QuantizeModel(CustomNode):
             "final_layer",
         }
         if strategy == "quanto":
-            logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+            logger.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
             self.warn_in_place(model)
             from optimum.quanto import quantize, qint8  # pylint: disable=import-error
             exclusion_list = [
```
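A hedged usage sketch of the updated node: a VAE can now be routed through `TorchCompileModel`, in which case its encoder and decoder are wrapped via the clone's own `ModelPatcher`. The parameter values mirror the node defaults shown above; exact workflow wiring depends on your graph, and `node` / `vae` are passed in to keep the sketch self-contained.

```python
def compile_vae(node, vae):
    # node: a TorchCompileModel instance; vae: a loaded comfy.sd.VAE.
    # The node clones the VAE first (see VAE.clone above), so the original stays unpatched.
    compiled_vae, = node.patch(vae, object_patch="", backend="inductor", mode="max-autotune")
    return compiled_vae
```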
```diff
@@ -57,7 +57,7 @@ dependencies = [
     "pyjwt[crypto]",
     "kornia>=0.7.0",
     "mpmath>=1.0,!=1.4.0a0",
-    "huggingface_hub[hf_transfer]>0.20",
+    "huggingface_hub[hf_xet]>=0.32.0",
     "lazy-object-proxy",
     "lazy_loader>=0.3",
     "can_ada",
@@ -76,7 +76,8 @@ dependencies = [
     "wrapt>=1.16.0",
     "certifi",
     "spandrel>=0.3.4",
-    "numpy>=1.24.4",
+    # https://github.com/conda-forge/numba-feedstock/issues/158 until numba is released with support for a later version of numpy
+    "numpy>=1.24.4,<2.3",
     "soundfile",
     "watchdog",
     "PySoundFile",
```
```diff
@@ -15,7 +15,7 @@ import requests

 os.environ['OTEL_METRICS_EXPORTER'] = 'none'
 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
 # fixes issues with running the testcontainers rabbitmqcontainer on Windows
 os.environ["TC_HOST"] = "localhost"

```