mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00
dynamic batch size
This commit is contained in:
parent
c8195ab3e5
commit
f380b89a71
@ -505,12 +505,11 @@ class AITemplateModelWrapper:
|
|||||||
args = args[:2]
|
args = args[:2]
|
||||||
if kwargs.get("cond", None) is not None:
|
if kwargs.get("cond", None) is not None:
|
||||||
encoder_hidden_states = kwargs.pop("cond")
|
encoder_hidden_states = kwargs.pop("cond")
|
||||||
|
latent_model_input, timesteps = args
|
||||||
encoder_hidden_states = encoder_hidden_states[0][0]
|
encoder_hidden_states = encoder_hidden_states[0][0]
|
||||||
encoder_hidden_states = torch.cat([encoder_hidden_states] * 2)
|
encoder_hidden_states = torch.cat([encoder_hidden_states] * 2)
|
||||||
latent_model_input, timesteps = args
|
timesteps_pt = torch.cat([timesteps] * 2)
|
||||||
timesteps_pt = timesteps.expand(2)
|
latent_model_input = torch.cat([latent_model_input] * 2)
|
||||||
if latent_model_input.shape[0] < 2:
|
|
||||||
latent_model_input = torch.cat([latent_model_input] * 2)
|
|
||||||
height = latent_model_input.shape[2]
|
height = latent_model_input.shape[2]
|
||||||
width = latent_model_input.shape[3]
|
width = latent_model_input.shape[3]
|
||||||
|
|
||||||
@ -526,7 +525,7 @@ class AITemplateModelWrapper:
|
|||||||
num_outputs = len(self.unet_ait_exe.get_output_name_to_index_map())
|
num_outputs = len(self.unet_ait_exe.get_output_name_to_index_map())
|
||||||
for i in range(num_outputs):
|
for i in range(num_outputs):
|
||||||
shape = self.unet_ait_exe.get_output_maximum_shape(i)
|
shape = self.unet_ait_exe.get_output_maximum_shape(i)
|
||||||
shape[0] = 2
|
shape[0] = latent_model_input.shape[0]
|
||||||
shape[1] = height
|
shape[1] = height
|
||||||
shape[2] = width
|
shape[2] = width
|
||||||
ys.append(torch.empty(shape).cuda().half())
|
ys.append(torch.empty(shape).cuda().half())
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user