Fix issues with tests

This commit is contained in:
doctorpangloss 2025-04-04 08:27:33 -07:00
parent 2003e6ae65
commit ffc1912eff
5 changed files with 15 additions and 15 deletions

View File

@ -20,6 +20,8 @@ from ..cmd.folder_paths import add_model_folder_path # pylint: disable=import-e
REQUEST_TIMEOUT = 10 # seconds REQUEST_TIMEOUT = 10 # seconds
def check_frontend_version():
return None
class Asset(TypedDict): class Asset(TypedDict):
url: str url: str
@ -162,7 +164,7 @@ class FrontendManager:
main error source might be request timeout or invalid URL. main error source might be request timeout or invalid URL.
""" """
if version_string == DEFAULT_VERSION_STRING: if version_string == DEFAULT_VERSION_STRING:
# check_frontend_version() check_frontend_version()
return cls.default_frontend_path() return cls.default_frontend_path()
repo_owner, repo_name, version = cls.parse_version_string(version_string) repo_owner, repo_name, version = cls.parse_version_string(version_string)
@ -224,5 +226,5 @@ class FrontendManager:
except Exception as e: except Exception as e:
logging.error("Failed to initialize frontend: %s", e) logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.") logging.info("Falling back to the default frontend.")
# check_frontend_version() check_frontend_version()
return cls.default_frontend_path() return cls.default_frontend_path()

View File

@ -85,7 +85,7 @@ class Hunyuan3Dv2(nn.Module):
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks): for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap1(args):
out = {} out = {}
out["img"], out["txt"] = block(img=args["img"], out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"], txt=args["txt"],
@ -99,7 +99,7 @@ class Hunyuan3Dv2(nn.Module):
"vec": vec, "vec": vec,
"pe": pe, "pe": pe,
"attn_mask": attn_mask}, "attn_mask": attn_mask},
{"original_block": block_wrap}) {"original_block": block_wrap1})
txt = out["txt"] txt = out["txt"]
img = out["img"] img = out["img"]
else: else:

View File

@ -22,11 +22,9 @@ else:
sageattn = torch.nn.functional.scaled_dot_product_attention sageattn = torch.nn.functional.scaled_dot_product_attention
if model_management.flash_attention_enabled(): if model_management.flash_attention_enabled():
try: from flash_attn import flash_attn_func # pylint: disable=import-error
from flash_attn import flash_attn_func else:
except ModuleNotFoundError: flash_attn_func = torch.nn.functional.scaled_dot_product_attention
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
from ...cli_args import args from ...cli_args import args
from ... import ops from ... import ops
@ -546,7 +544,7 @@ try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) # pylint: disable=possibly-used-before-assignment,used-before-assignment
@flash_attn_wrapper.register_fake @flash_attn_wrapper.register_fake

View File

@ -40,7 +40,7 @@ from .ldm.hunyuan_video.model import HunyuanVideo as HunyuanVideoModel
from .ldm.hydit.models import HunYuanDiT from .ldm.hydit.models import HunYuanDiT
from .ldm.lightricks.model import LTXVModel from .ldm.lightricks.model import LTXVModel
from .ldm.lumina.model import NextDiT from .ldm.lumina.model import NextDiT
from .ldm.hunyuan3d.model import Hunyuan3Dv2 from .ldm.hunyuan3d.model import Hunyuan3Dv2 as Hunyuan3Dv2Model
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
@ -64,7 +64,7 @@ class ModelType(Enum):
IMG_TO_IMG = 9 IMG_TO_IMG = 9
from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux, IMG_TO_IMG
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
@ -94,7 +94,7 @@ def model_sampling(model_config, model_type):
c = CONST c = CONST
s = ModelSamplingFlux s = ModelSamplingFlux
elif model_type == ModelType.IMG_TO_IMG: elif model_type == ModelType.IMG_TO_IMG:
c = model_sampling.IMG_TO_IMG c = IMG_TO_IMG
class ModelSampling(s, c): class ModelSampling(s, c):
pass pass
@ -1081,7 +1081,7 @@ class WAN21(BaseModel):
class Hunyuan3Dv2(BaseModel): class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2) super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model)
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)

View File

@ -1103,7 +1103,7 @@ def flash_attn_enabled():
return False return False
if directml_device: if directml_device:
return False return False
return FLASH_ATTENTION_ENABLED return flash_attention_enabled()
def xformers_enabled_vae(): def xformers_enabled_vae():