mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 16:32:34 +08:00
dynamic batch size
This commit is contained in:
parent
c8195ab3e5
commit
f380b89a71
@ -505,12 +505,11 @@ class AITemplateModelWrapper:
|
||||
args = args[:2]
|
||||
if kwargs.get("cond", None) is not None:
|
||||
encoder_hidden_states = kwargs.pop("cond")
|
||||
latent_model_input, timesteps = args
|
||||
encoder_hidden_states = encoder_hidden_states[0][0]
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states] * 2)
|
||||
latent_model_input, timesteps = args
|
||||
timesteps_pt = timesteps.expand(2)
|
||||
if latent_model_input.shape[0] < 2:
|
||||
latent_model_input = torch.cat([latent_model_input] * 2)
|
||||
timesteps_pt = torch.cat([timesteps] * 2)
|
||||
latent_model_input = torch.cat([latent_model_input] * 2)
|
||||
height = latent_model_input.shape[2]
|
||||
width = latent_model_input.shape[3]
|
||||
|
||||
@ -526,7 +525,7 @@ class AITemplateModelWrapper:
|
||||
num_outputs = len(self.unet_ait_exe.get_output_name_to_index_map())
|
||||
for i in range(num_outputs):
|
||||
shape = self.unet_ait_exe.get_output_maximum_shape(i)
|
||||
shape[0] = 2
|
||||
shape[0] = latent_model_input.shape[0]
|
||||
shape[1] = height
|
||||
shape[2] = width
|
||||
ys.append(torch.empty(shape).cuda().half())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user