Fix issues with tests

doctorpangloss 2025-04-04 08:27:33 -07:00
parent 2003e6ae65
commit ffc1912eff
5 changed files with 15 additions and 15 deletions


@@ -20,6 +20,8 @@ from ..cmd.folder_paths import add_model_folder_path # pylint: disable=import-e
 REQUEST_TIMEOUT = 10 # seconds
+def check_frontend_version():
+    return None
 class Asset(TypedDict):
     url: str
@@ -162,7 +164,7 @@ class FrontendManager:
         main error source might be request timeout or invalid URL.
         """
         if version_string == DEFAULT_VERSION_STRING:
-            # check_frontend_version()
+            check_frontend_version()
             return cls.default_frontend_path()
         repo_owner, repo_name, version = cls.parse_version_string(version_string)
@@ -224,5 +226,5 @@ class FrontendManager:
         except Exception as e:
             logging.error("Failed to initialize frontend: %s", e)
             logging.info("Falling back to the default frontend.")
-            # check_frontend_version()
+            check_frontend_version()
             return cls.default_frontend_path()
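Note: the two hunks above re-enable the check_frontend_version() call sites against the new no-op stub added in the first hunk, so the calls stay in place without doing any real work. A minimal sketch of that stub-and-call pattern and how a test can hook it (hypothetical names, not this repo's actual test code):

    from unittest.mock import patch

    def check_frontend_version():
        return None  # stubbed out: no network access during tests

    def init_frontend(version_string="latest"):
        check_frontend_version()  # safe to call unconditionally against the stub
        return "/path/to/default/frontend"

    # A test can swap the stub back in to assert the call site is still wired up:
    with patch(f"{__name__}.check_frontend_version") as mock_check:
        init_frontend()
        mock_check.assert_called_once()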


@@ -85,7 +85,7 @@ class Hunyuan3Dv2(nn.Module):
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
             if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
+                def block_wrap1(args):
                     out = {}
                     out["img"], out["txt"] = block(img=args["img"],
                                                    txt=args["txt"],
@@ -99,7 +99,7 @@ class Hunyuan3Dv2(nn.Module):
                                                            "vec": vec,
                                                            "pe": pe,
                                                            "attn_mask": attn_mask},
-                                                          {"original_block": block_wrap})
+                                                          {"original_block": block_wrap1})
                 txt = out["txt"]
                 img = out["img"]
             else:
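The rename from block_wrap to block_wrap1 likely sidesteps a lint warning about redefining the same inner function name elsewhere in this module. For context, a minimal runnable sketch of the patch protocol these hunks feed: a replacement registered under ("double_block", i) is called with the block's inputs plus an extra dict carrying the wrapped original computation under "original_block". The names below are illustrative assumptions, not this repo's API surface:

    def block_wrap1(args):
        # stands in for the model's wrapper around the real double block
        return {"img": args["img"] + 1, "txt": args["txt"]}

    def my_double_block_patch(args, extra_options):
        original_block = extra_options["original_block"]
        out = original_block(args)   # run the unmodified block
        out["img"] = out["img"] * 2  # then post-process one stream
        return out

    out = my_double_block_patch({"img": 1, "txt": 2},
                                {"original_block": block_wrap1})
    assert out == {"img": 4, "txt": 2}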


@@ -22,11 +22,9 @@ else:
     sageattn = torch.nn.functional.scaled_dot_product_attention
 if model_management.flash_attention_enabled():
-    try:
-        from flash_attn import flash_attn_func
-    except ModuleNotFoundError:
-        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
-        exit(-1)
+    from flash_attn import flash_attn_func # pylint: disable=import-error
 else:
     flash_attn_func = torch.nn.functional.scaled_dot_product_attention
 from ...cli_args import args
 from ... import ops
@@ -546,7 +544,7 @@ try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
     def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
-        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) # pylint: disable=possibly-used-before-assignment,used-before-assignment
     @flash_attn_wrapper.register_fake
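The first hunk above drops the try/except around the flash_attn import and lets a missing package raise naturally; the pylint disable suppresses the unresolved-import warning instead. For reference, a self-contained sketch of the import-or-fallback pattern this file is built around, assuming flash_attn may be absent at runtime. Unlike the file's own fallback (which assigns SDPA directly), this shim also transposes, since flash_attn takes (batch, seq, heads, dim) while SDPA takes (batch, heads, seq, dim):

    import torch

    try:
        from flash_attn import flash_attn_func  # fused kernel, if installed
    except ImportError:
        def flash_attn_func(q, k, v, dropout_p=0.0, causal=False):
            # Fall back to PyTorch's built-in scaled dot-product attention,
            # converting between the two layout conventions on the way in/out.
            out = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                dropout_p=dropout_p, is_causal=causal)
            return out.transpose(1, 2)

    q = k = v = torch.randn(1, 16, 8, 64)  # (batch, seq, heads, head_dim)
    print(flash_attn_func(q, k, v, causal=True).shape)  # -> (1, 16, 8, 64)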


@@ -40,7 +40,7 @@ from .ldm.hunyuan_video.model import HunyuanVideo as HunyuanVideoModel
 from .ldm.hydit.models import HunYuanDiT
 from .ldm.lightricks.model import LTXVModel
 from .ldm.lumina.model import NextDiT
-from .ldm.hunyuan3d.model import Hunyuan3Dv2
+from .ldm.hunyuan3d.model import Hunyuan3Dv2 as Hunyuan3Dv2Model
 from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
 from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
@@ -64,7 +64,7 @@ class ModelType(Enum):
     IMG_TO_IMG = 9
-from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux
+from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux, IMG_TO_IMG
 def model_sampling(model_config, model_type):
@@ -94,7 +94,7 @@ def model_sampling(model_config, model_type):
         c = CONST
         s = ModelSamplingFlux
     elif model_type == ModelType.IMG_TO_IMG:
-        c = model_sampling.IMG_TO_IMG
+        c = IMG_TO_IMG
     class ModelSampling(s, c):
         pass
@@ -1081,7 +1081,7 @@ class WAN21(BaseModel):
 class Hunyuan3Dv2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
-        super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2)
+        super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model)
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
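Two distinct fixes here: the import alias Hunyuan3Dv2Model resolves the name collision with the BaseModel subclass (the old code passed the class itself as its own unet_model), and the module-level IMG_TO_IMG import replaces `model_sampling.IMG_TO_IMG`, which would have looked up an attribute on the enclosing function. The surrounding model_sampling() builds its return type by mixing a schedule class and a prediction class; a minimal runnable sketch of that composition pattern, with illustrative stand-in classes rather than the project's real ones:

    class DiscreteSchedule:
        def timestep(self, sigma):
            return sigma * 1000

    class EpsilonPrediction:
        def calculate_denoised(self, sigma, model_output, model_input):
            return model_input - model_output * sigma

    def make_model_sampling(schedule_cls, prediction_cls):
        # compose the two orthogonal behaviors into one concrete type
        class ModelSampling(schedule_cls, prediction_cls):
            pass
        return ModelSampling()

    sampling = make_model_sampling(DiscreteSchedule, EpsilonPrediction)
    print(sampling.timestep(0.5))  # 500.0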


@@ -1103,7 +1103,7 @@ def flash_attn_enabled():
         return False
     if directml_device:
         return False
-    return FLASH_ATTENTION_ENABLED
+    return flash_attention_enabled()
 def xformers_enabled_vae():
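The last hunk swaps a module-level constant for a function call, so the result reflects state at call time rather than a snapshot frozen at import, which matters when tests or CLI flags toggle the setting after the module loads. A small self-contained illustration of the difference (names here are stand-ins, not this module's real globals):

    _enabled = False

    FLASH_ATTENTION_ENABLED = _enabled  # snapshot taken once, at import

    def flash_attention_enabled() -> bool:
        return _enabled                 # re-evaluated on every call

    def set_enabled(value: bool):
        global _enabled
        _enabled = value

    set_enabled(True)
    assert FLASH_ATTENTION_ENABLED is False   # stale snapshot
    assert flash_attention_enabled() is True  # sees the update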