diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index df355acc8..3e0432711 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -1,5 +1,6 @@ from .wav2vec2 import Wav2Vec2Model from ..model_management import text_encoder_offload_device, text_encoder_device, load_model_gpu, text_encoder_dtype +from ..model_patcher import ModelPatcher from ..ops import manual_cast from ..utils import state_dict_prefix_replace @@ -12,7 +13,7 @@ class AudioEncoderModel(): self.dtype = text_encoder_dtype(self.load_device) self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=manual_cast) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 def load_sd(self, sd): diff --git a/comfy/audio_encoders/wav2vec2.py b/comfy/audio_encoders/wav2vec2.py index de906622a..43c503bf8 100644 --- a/comfy/audio_encoders/wav2vec2.py +++ b/comfy/audio_encoders/wav2vec2.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from comfy.ldm.modules.attention import optimized_attention_masked +from ..ldm.modules.attention import optimized_attention_masked class LayerNormConv(nn.Module): diff --git a/comfy/client/client_types.py b/comfy/client/client_types.py index d5fd6b1c4..7f1d55ec2 100644 --- a/comfy/client/client_types.py +++ b/comfy/client/client_types.py @@ -3,7 +3,7 @@ from typing import List, NamedTuple, Optional from typing_extensions import TypedDict, Literal, NotRequired -from comfy.component_model.executor_types import SendSyncEvent, SendSyncData +from ..component_model.executor_types import SendSyncEvent, SendSyncData class FileOutput(TypedDict, total=False): diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index d3af0acd6..e5d641c20 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -8,7 +8,6 @@ import logging import threading import uuid from asyncio import get_event_loop -from dataclasses import dataclass from multiprocessing import RLock from typing import Optional, Generator @@ -16,8 +15,8 @@ from opentelemetry import context, propagate from opentelemetry.context import Context, attach, detach from opentelemetry.trace import Status, StatusCode -from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress from ..cmd.main_pre import tracer +from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress from .client_types import V1QueuePromptResponse from ..api.components.schema.prompt import PromptDict from ..cli_args_types import Configuration diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index fcfa063e1..d3fdfe3eb 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -953,7 +953,7 @@ class MotionEncoder_tc(nn.Module): x = self.norm3(x) x = self.act(x) x = rearrange(x, '(b n) t c -> b t n c', b=b) - padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1) + padding = cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1) x = torch.cat([x, padding], dim=-2) x_local = x.clone() @@ -1005,7 +1005,7 @@ class CausalAudioEncoder(nn.Module): def forward(self, features): # features B * num_layers * dim * video_length - weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device)) + weights = self.act(cast_to(self.weights, dtype=features.dtype, device=features.device)) weights_sum = weights.sum(dim=1, keepdims=True) weighted_feat = ((features * weights) / weights_sum).sum( dim=1) # b dim f @@ -1267,7 +1267,7 @@ class WanModel_S2V(WanModel): x = x.flatten(2).transpose(1, 2) seq_len = x.size(1) - cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1) + cond_mask_weight = cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1) x = x + cond_mask_weight[0] if reference_latent is not None: diff --git a/comfy/model_base.py b/comfy/model_base.py index f8aeb62ef..e98403e09 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -52,7 +52,7 @@ from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentati from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel from .ldm.pixart.pixartms import PixArtMS from .ldm.qwen_image.model import QwenImageTransformer2DModel -from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel +from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel, WanModel_S2V from .model_management_types import ModelManageable from .model_sampling import CONST, ModelSamplingDiscreteFlow, ModelSamplingFlux, IMG_TO_IMG from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCosmosRFlow, V_PREDICTION, \ @@ -1258,7 +1258,7 @@ class WAN21_Camera(WAN21): class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=WanModel_S2V) self.memory_usage_factor_conds = ("reference_latent", "reference_motion") self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} @@ -1266,19 +1266,19 @@ class WAN22_S2V(WAN21): out = super().extra_conds(**kwargs) audio_embed = kwargs.get("audio_embed", None) if audio_embed is not None: - out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) + out['audio_embed'] = conds.CONDRegular(audio_embed) reference_latents = kwargs.get("reference_latents", None) if reference_latents is not None: - out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + out['reference_latent'] = conds.CONDRegular(self.process_latent_in(reference_latents[-1])) reference_motion = kwargs.get("reference_motion", None) if reference_motion is not None: - out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion)) + out['reference_motion'] = conds.CONDRegular(self.process_latent_in(reference_motion)) control_video = kwargs.get("control_video", None) if control_video is not None: - out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video)) + out['control_video'] = conds.CONDRegular(self.process_latent_in(control_video)) return out def extra_conds_shapes(self, **kwargs): diff --git a/comfy/samplers.py b/comfy/samplers.py index 615775959..e8a6311f2 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -771,7 +771,7 @@ class Sampler: return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma -@_module_properties.getter() +@_module_properties.getter def _KSAMPLER_NAMES(): return KSAMPLER_NAMES diff --git a/comfy_extras/nodes/nodes_latent.py b/comfy_extras/nodes/nodes_latent.py index 2bd9a17d0..ae60a8ca5 100644 --- a/comfy_extras/nodes/nodes_latent.py +++ b/comfy_extras/nodes/nodes_latent.py @@ -1,5 +1,5 @@ import torch -import nodes +from comfy.nodes import base_nodes as nodes import comfy.utils diff --git a/comfy_extras/nodes/nodes_model_patch.py b/comfy_extras/nodes/nodes_model_patch.py index b127cf8a7..8723b8981 100644 --- a/comfy_extras/nodes/nodes_model_patch.py +++ b/comfy_extras/nodes/nodes_model_patch.py @@ -6,6 +6,7 @@ import comfy.ops import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats +from comfy.model_patcher import ModelPatcher class BlockWiseControlBlock(torch.nn.Module): @@ -207,6 +208,7 @@ class ModelPatchLoader: sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) dtype = comfy.utils.weight_dtype(sd) + model = None if 'controlnet_blocks.0.y_rms.weight' in sd: additional_in_dim = sd["img_in.weight"].shape[1] - 64 model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) @@ -215,7 +217,7 @@ class ModelPatchLoader: model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model.load_state_dict(sd) - model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + model = ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) return (model,)