mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
Always set diffusion model to eval() mode. (#10331)
This commit is contained in:
parent
d68ece7301
commit
e693e4db6a
@ -138,6 +138,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
|
self.diffusion_model.eval()
|
||||||
if comfy.model_management.force_channels_last():
|
if comfy.model_management.force_channels_last():
|
||||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||||
logging.debug("using channels last mode for diffusion model")
|
logging.debug("using channels last mode for diffusion model")
|
||||||
@ -669,7 +670,6 @@ class Lotus(BaseModel):
|
|||||||
class StableCascade_C(BaseModel):
|
class StableCascade_C(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||||
self.diffusion_model.eval().requires_grad_(False)
|
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
@ -698,7 +698,6 @@ class StableCascade_C(BaseModel):
|
|||||||
class StableCascade_B(BaseModel):
|
class StableCascade_B(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
||||||
self.diffusion_model.eval().requires_grad_(False)
|
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user