mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 18:27:40 +08:00
Fix anima preprocess text embeds not using right inference dtype. (#12501)
This commit is contained in:
parent
18927538a1
commit
c39653163d
@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
||||||
|
|
||||||
context = c_crossattn
|
context = c_crossattn
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype_inference()
|
||||||
|
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
|
|
||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
device = xc.device
|
device = xc.device
|
||||||
@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
|
|||||||
def get_dtype(self):
|
def get_dtype(self):
|
||||||
return self.diffusion_model.dtype
|
return self.diffusion_model.dtype
|
||||||
|
|
||||||
|
def get_dtype_inference(self):
|
||||||
|
dtype = self.get_dtype()
|
||||||
|
|
||||||
|
if self.manual_cast_dtype is not None:
|
||||||
|
dtype = self.manual_cast_dtype
|
||||||
|
return dtype
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
input_shapes += shape
|
input_shapes += shape
|
||||||
|
|
||||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype_inference()
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
#TODO: this needs to be tweaked
|
#TODO: this needs to be tweaked
|
||||||
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
||||||
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
||||||
@ -1165,7 +1167,7 @@ class Anima(BaseModel):
|
|||||||
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||||
|
|
||||||
if torch.is_inference_mode_enabled(): # if not we are training
|
if torch.is_inference_mode_enabled(): # if not we are training
|
||||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
|
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
|
||||||
else:
|
else:
|
||||||
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user