Fix pylint issues

This commit is contained in:
doctorpangloss 2025-09-03 12:13:35 -07:00
parent 179c2d35c8
commit 8052d070de
9 changed files with 19 additions and 17 deletions

View File

@ -1,5 +1,6 @@
from .wav2vec2 import Wav2Vec2Model
from ..model_management import text_encoder_offload_device, text_encoder_device, load_model_gpu, text_encoder_dtype
from ..model_patcher import ModelPatcher
from ..ops import manual_cast
from ..utils import state_dict_prefix_replace
@ -12,7 +13,7 @@ class AudioEncoderModel():
self.dtype = text_encoder_dtype(self.load_device)
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
def load_sd(self, sd):

View File

@ -1,6 +1,6 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention_masked
from ..ldm.modules.attention import optimized_attention_masked
class LayerNormConv(nn.Module):

View File

@ -3,7 +3,7 @@ from typing import List, NamedTuple, Optional
from typing_extensions import TypedDict, Literal, NotRequired
from comfy.component_model.executor_types import SendSyncEvent, SendSyncData
from ..component_model.executor_types import SendSyncEvent, SendSyncData
class FileOutput(TypedDict, total=False):

View File

@ -8,7 +8,6 @@ import logging
import threading
import uuid
from asyncio import get_event_loop
from dataclasses import dataclass
from multiprocessing import RLock
from typing import Optional, Generator
@ -16,8 +15,8 @@ from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode
from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress
from ..cmd.main_pre import tracer
from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration

View File

@ -953,7 +953,7 @@ class MotionEncoder_tc(nn.Module):
x = self.norm3(x)
x = self.act(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
padding = cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
@ -1005,7 +1005,7 @@ class CausalAudioEncoder(nn.Module):
def forward(self, features):
# features B * num_layers * dim * video_length
weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device))
weights = self.act(cast_to(self.weights, dtype=features.dtype, device=features.device))
weights_sum = weights.sum(dim=1, keepdims=True)
weighted_feat = ((features * weights) / weights_sum).sum(
dim=1) # b dim f
@ -1267,7 +1267,7 @@ class WanModel_S2V(WanModel):
x = x.flatten(2).transpose(1, 2)
seq_len = x.size(1)
cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
cond_mask_weight = cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
x = x + cond_mask_weight[0]
if reference_latent is not None:

View File

@ -52,7 +52,7 @@ from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentati
from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
from .ldm.pixart.pixartms import PixArtMS
from .ldm.qwen_image.model import QwenImageTransformer2DModel
from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel
from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel, WanModel_S2V
from .model_management_types import ModelManageable
from .model_sampling import CONST, ModelSamplingDiscreteFlow, ModelSamplingFlux, IMG_TO_IMG
from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCosmosRFlow, V_PREDICTION, \
@ -1258,7 +1258,7 @@ class WAN21_Camera(WAN21):
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=WanModel_S2V)
self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
@ -1266,19 +1266,19 @@ class WAN22_S2V(WAN21):
out = super().extra_conds(**kwargs)
audio_embed = kwargs.get("audio_embed", None)
if audio_embed is not None:
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
out['audio_embed'] = conds.CONDRegular(audio_embed)
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
out['reference_latent'] = conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
reference_motion = kwargs.get("reference_motion", None)
if reference_motion is not None:
out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion))
out['reference_motion'] = conds.CONDRegular(self.process_latent_in(reference_motion))
control_video = kwargs.get("control_video", None)
if control_video is not None:
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
out['control_video'] = conds.CONDRegular(self.process_latent_in(control_video))
return out
def extra_conds_shapes(self, **kwargs):

View File

@ -771,7 +771,7 @@ class Sampler:
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
@_module_properties.getter()
@_module_properties.getter
def _KSAMPLER_NAMES():
return KSAMPLER_NAMES

View File

@ -1,5 +1,5 @@
import torch
import nodes
from comfy.nodes import base_nodes as nodes
import comfy.utils

View File

@ -6,6 +6,7 @@ import comfy.ops
import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
from comfy.model_patcher import ModelPatcher
class BlockWiseControlBlock(torch.nn.Module):
@ -207,6 +208,7 @@ class ModelPatchLoader:
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
dtype = comfy.utils.weight_dtype(sd)
model = None
if 'controlnet_blocks.0.y_rms.weight' in sd:
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
@ -215,7 +217,7 @@ class ModelPatchLoader:
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
model = ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)