mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Fix issues with tests
This commit is contained in:
parent
2003e6ae65
commit
ffc1912eff
@ -20,6 +20,8 @@ from ..cmd.folder_paths import add_model_folder_path # pylint: disable=import-e
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
def check_frontend_version():
|
||||
return None
|
||||
|
||||
class Asset(TypedDict):
|
||||
url: str
|
||||
@ -162,7 +164,7 @@ class FrontendManager:
|
||||
main error source might be request timeout or invalid URL.
|
||||
"""
|
||||
if version_string == DEFAULT_VERSION_STRING:
|
||||
# check_frontend_version()
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||
@ -224,5 +226,5 @@ class FrontendManager:
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
# check_frontend_version()
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
@ -85,7 +85,7 @@ class Hunyuan3Dv2(nn.Module):
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
def block_wrap1(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
@ -99,7 +99,7 @@ class Hunyuan3Dv2(nn.Module):
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
{"original_block": block_wrap1})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
|
||||
@ -22,11 +22,9 @@ else:
|
||||
sageattn = torch.nn.functional.scaled_dot_product_attention
|
||||
|
||||
if model_management.flash_attention_enabled():
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except ModuleNotFoundError:
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
from flash_attn import flash_attn_func # pylint: disable=import-error
|
||||
else:
|
||||
flash_attn_func = torch.nn.functional.scaled_dot_product_attention
|
||||
|
||||
from ...cli_args import args
|
||||
from ... import ops
|
||||
@ -546,7 +544,7 @@ try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
|
||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) # pylint: disable=possibly-used-before-assignment,used-before-assignment
|
||||
|
||||
|
||||
@flash_attn_wrapper.register_fake
|
||||
|
||||
@ -40,7 +40,7 @@ from .ldm.hunyuan_video.model import HunyuanVideo as HunyuanVideoModel
|
||||
from .ldm.hydit.models import HunYuanDiT
|
||||
from .ldm.lightricks.model import LTXVModel
|
||||
from .ldm.lumina.model import NextDiT
|
||||
from .ldm.hunyuan3d.model import Hunyuan3Dv2
|
||||
from .ldm.hunyuan3d.model import Hunyuan3Dv2 as Hunyuan3Dv2Model
|
||||
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
@ -64,7 +64,7 @@ class ModelType(Enum):
|
||||
IMG_TO_IMG = 9
|
||||
|
||||
|
||||
from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux
|
||||
from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux, IMG_TO_IMG
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -94,7 +94,7 @@ def model_sampling(model_config, model_type):
|
||||
c = CONST
|
||||
s = ModelSamplingFlux
|
||||
elif model_type == ModelType.IMG_TO_IMG:
|
||||
c = model_sampling.IMG_TO_IMG
|
||||
c = IMG_TO_IMG
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -1081,7 +1081,7 @@ class WAN21(BaseModel):
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2)
|
||||
super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
@ -1103,7 +1103,7 @@ def flash_attn_enabled():
|
||||
return False
|
||||
if directml_device:
|
||||
return False
|
||||
return FLASH_ATTENTION_ENABLED
|
||||
return flash_attention_enabled()
|
||||
|
||||
|
||||
def xformers_enabled_vae():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user