mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Fix issues with tests
This commit is contained in:
parent
2003e6ae65
commit
ffc1912eff
@ -20,6 +20,8 @@ from ..cmd.folder_paths import add_model_folder_path # pylint: disable=import-e
|
|||||||
|
|
||||||
REQUEST_TIMEOUT = 10 # seconds
|
REQUEST_TIMEOUT = 10 # seconds
|
||||||
|
|
||||||
|
def check_frontend_version():
|
||||||
|
return None
|
||||||
|
|
||||||
class Asset(TypedDict):
|
class Asset(TypedDict):
|
||||||
url: str
|
url: str
|
||||||
@ -162,7 +164,7 @@ class FrontendManager:
|
|||||||
main error source might be request timeout or invalid URL.
|
main error source might be request timeout or invalid URL.
|
||||||
"""
|
"""
|
||||||
if version_string == DEFAULT_VERSION_STRING:
|
if version_string == DEFAULT_VERSION_STRING:
|
||||||
# check_frontend_version()
|
check_frontend_version()
|
||||||
return cls.default_frontend_path()
|
return cls.default_frontend_path()
|
||||||
|
|
||||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||||
@ -224,5 +226,5 @@ class FrontendManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Failed to initialize frontend: %s", e)
|
logging.error("Failed to initialize frontend: %s", e)
|
||||||
logging.info("Falling back to the default frontend.")
|
logging.info("Falling back to the default frontend.")
|
||||||
# check_frontend_version()
|
check_frontend_version()
|
||||||
return cls.default_frontend_path()
|
return cls.default_frontend_path()
|
||||||
|
|||||||
@ -85,7 +85,7 @@ class Hunyuan3Dv2(nn.Module):
|
|||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
for i, block in enumerate(self.double_blocks):
|
for i, block in enumerate(self.double_blocks):
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap1(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"], out["txt"] = block(img=args["img"],
|
out["img"], out["txt"] = block(img=args["img"],
|
||||||
txt=args["txt"],
|
txt=args["txt"],
|
||||||
@ -99,7 +99,7 @@ class Hunyuan3Dv2(nn.Module):
|
|||||||
"vec": vec,
|
"vec": vec,
|
||||||
"pe": pe,
|
"pe": pe,
|
||||||
"attn_mask": attn_mask},
|
"attn_mask": attn_mask},
|
||||||
{"original_block": block_wrap})
|
{"original_block": block_wrap1})
|
||||||
txt = out["txt"]
|
txt = out["txt"]
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -22,11 +22,9 @@ else:
|
|||||||
sageattn = torch.nn.functional.scaled_dot_product_attention
|
sageattn = torch.nn.functional.scaled_dot_product_attention
|
||||||
|
|
||||||
if model_management.flash_attention_enabled():
|
if model_management.flash_attention_enabled():
|
||||||
try:
|
from flash_attn import flash_attn_func # pylint: disable=import-error
|
||||||
from flash_attn import flash_attn_func
|
else:
|
||||||
except ModuleNotFoundError:
|
flash_attn_func = torch.nn.functional.scaled_dot_product_attention
|
||||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
|
||||||
exit(-1)
|
|
||||||
|
|
||||||
from ...cli_args import args
|
from ...cli_args import args
|
||||||
from ... import ops
|
from ... import ops
|
||||||
@ -546,7 +544,7 @@ try:
|
|||||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||||
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||||
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
|
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) # pylint: disable=possibly-used-before-assignment,used-before-assignment
|
||||||
|
|
||||||
|
|
||||||
@flash_attn_wrapper.register_fake
|
@flash_attn_wrapper.register_fake
|
||||||
|
|||||||
@ -40,7 +40,7 @@ from .ldm.hunyuan_video.model import HunyuanVideo as HunyuanVideoModel
|
|||||||
from .ldm.hydit.models import HunYuanDiT
|
from .ldm.hydit.models import HunYuanDiT
|
||||||
from .ldm.lightricks.model import LTXVModel
|
from .ldm.lightricks.model import LTXVModel
|
||||||
from .ldm.lumina.model import NextDiT
|
from .ldm.lumina.model import NextDiT
|
||||||
from .ldm.hunyuan3d.model import Hunyuan3Dv2
|
from .ldm.hunyuan3d.model import Hunyuan3Dv2 as Hunyuan3Dv2Model
|
||||||
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||||
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||||
@ -64,7 +64,7 @@ class ModelType(Enum):
|
|||||||
IMG_TO_IMG = 9
|
IMG_TO_IMG = 9
|
||||||
|
|
||||||
|
|
||||||
from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux
|
from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux, IMG_TO_IMG
|
||||||
|
|
||||||
|
|
||||||
def model_sampling(model_config, model_type):
|
def model_sampling(model_config, model_type):
|
||||||
@ -94,7 +94,7 @@ def model_sampling(model_config, model_type):
|
|||||||
c = CONST
|
c = CONST
|
||||||
s = ModelSamplingFlux
|
s = ModelSamplingFlux
|
||||||
elif model_type == ModelType.IMG_TO_IMG:
|
elif model_type == ModelType.IMG_TO_IMG:
|
||||||
c = model_sampling.IMG_TO_IMG
|
c = IMG_TO_IMG
|
||||||
|
|
||||||
class ModelSampling(s, c):
|
class ModelSampling(s, c):
|
||||||
pass
|
pass
|
||||||
@ -1081,7 +1081,7 @@ class WAN21(BaseModel):
|
|||||||
|
|
||||||
class Hunyuan3Dv2(BaseModel):
|
class Hunyuan3Dv2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2)
|
super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model)
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
|
|||||||
@ -1103,7 +1103,7 @@ def flash_attn_enabled():
|
|||||||
return False
|
return False
|
||||||
if directml_device:
|
if directml_device:
|
||||||
return False
|
return False
|
||||||
return FLASH_ATTENTION_ENABLED
|
return flash_attention_enabled()
|
||||||
|
|
||||||
|
|
||||||
def xformers_enabled_vae():
|
def xformers_enabled_vae():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user