dynamic batch size

This commit is contained in:
hlky 2023-05-16 08:21:50 +01:00
parent c8195ab3e5
commit f380b89a71

View File

@@ -505,12 +505,11 @@ class AITemplateModelWrapper:
     args = args[:2]
     if kwargs.get("cond", None) is not None:
         encoder_hidden_states = kwargs.pop("cond")
+    latent_model_input, timesteps = args
     encoder_hidden_states = encoder_hidden_states[0][0]
     encoder_hidden_states = torch.cat([encoder_hidden_states] * 2)
-    latent_model_input, timesteps = args
-    timesteps_pt = timesteps.expand(2)
-    if latent_model_input.shape[0] < 2:
-        latent_model_input = torch.cat([latent_model_input] * 2)
+    timesteps_pt = torch.cat([timesteps] * 2)
+    latent_model_input = torch.cat([latent_model_input] * 2)
     height = latent_model_input.shape[2]
     width = latent_model_input.shape[3]
@@ -526,7 +525,7 @@ class AITemplateModelWrapper:
     num_outputs = len(self.unet_ait_exe.get_output_name_to_index_map())
     for i in range(num_outputs):
         shape = self.unet_ait_exe.get_output_maximum_shape(i)
-        shape[0] = 2
+        shape[0] = latent_model_input.shape[0]
         shape[1] = height
         shape[2] = width
         ys.append(torch.empty(shape).cuda().half())