Make conditioning broadcasting more robust, allows t.shape[0] to be any non-zero value

This commit is contained in:
James Walker 2023-08-30 18:58:53 +01:00
parent 18617967e5
commit c3f0e312c1

View File

@ -38,8 +38,10 @@ def broadcast_cond(cond, batch, device):
copy = []
for p in cond:
t = p[0]
if t.shape[0] < batch:
t = torch.cat([t] * batch)
if t.shape[0] != batch:
t_list = [t for _ in range(batch // t.shape[0])]
t_list.append(t[:batch % t.shape[0]])
t = torch.cat(t_list)
t = t.to(device)
copy += [[t] + p[1:]]
return copy