dynamic batch size

This commit is contained in:
hlky 2023-05-16 08:21:50 +01:00
parent c8195ab3e5
commit f380b89a71

View File

@ -505,12 +505,11 @@ class AITemplateModelWrapper:
args = args[:2]
if kwargs.get("cond", None) is not None:
encoder_hidden_states = kwargs.pop("cond")
latent_model_input, timesteps = args
encoder_hidden_states = encoder_hidden_states[0][0]
encoder_hidden_states = torch.cat([encoder_hidden_states] * 2)
latent_model_input, timesteps = args
timesteps_pt = timesteps.expand(2)
if latent_model_input.shape[0] < 2:
latent_model_input = torch.cat([latent_model_input] * 2)
timesteps_pt = torch.cat([timesteps] * 2)
latent_model_input = torch.cat([latent_model_input] * 2)
height = latent_model_input.shape[2]
width = latent_model_input.shape[3]
@ -526,7 +525,7 @@ class AITemplateModelWrapper:
num_outputs = len(self.unet_ait_exe.get_output_name_to_index_map())
for i in range(num_outputs):
shape = self.unet_ait_exe.get_output_maximum_shape(i)
shape[0] = 2
shape[0] = latent_model_input.shape[0]
shape[1] = height
shape[2] = width
ys.append(torch.empty(shape).cuda().half())