Fix depending on asserts to raise an exception in BatchedBrownianTree and Flash attn module (#9884)

Correctly handle the case where w0 is passed by kwargs in BatchedBrownianTree
blepping 2025-09-15 18:05:03 -06:00 committed by GitHub
parent 47a9cde5d3
commit 1a85483da1
2 changed files with 19 additions and 19 deletions
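
Background (not part of the commit itself): assert statements are compiled away when Python runs with the -O flag, so both guards touched here (assert len(seed) == x.shape[0] and assert mask is None) could silently vanish and let bad inputs through, while an explicit raise keeps the check active regardless of interpreter flags. A minimal sketch with hypothetical names:

# Minimal sketch (not from this commit): why guarding with assert is fragile.
# Running "python -O" strips assert statements, so the length check below disappears;
# the explicit raise keeps working regardless of optimization flags.
def guard_with_assert(seeds, batch_size):
    assert len(seeds) == batch_size  # compiled away under python -O
    return seeds

def guard_with_raise(seeds, batch_size):
    if len(seeds) != batch_size:
        raise ValueError("seed list length must match the batch size")
    return seeds

try:
    guard_with_raise([1, 2], batch_size=4)
except ValueError as exc:
    print("caught:", exc)  # raised with or without -O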

View File

@@ -86,24 +86,24 @@ class BatchedBrownianTree:
     """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
 
     def __init__(self, x, t0, t1, seed=None, **kwargs):
-        self.cpu_tree = True
-        if "cpu" in kwargs:
-            self.cpu_tree = kwargs.pop("cpu")
+        self.cpu_tree = kwargs.pop("cpu", True)
         t0, t1, self.sign = self.sort(t0, t1)
-        w0 = kwargs.get('w0', torch.zeros_like(x))
+        w0 = kwargs.pop('w0', None)
+        if w0 is None:
+            w0 = torch.zeros_like(x)
+        self.batched = False
         if seed is None:
-            seed = torch.randint(0, 2 ** 63 - 1, []).item()
-        self.batched = True
-        try:
-            assert len(seed) == x.shape[0]
+            seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
+        elif isinstance(seed, (tuple, list)):
+            if len(seed) != x.shape[0]:
+                raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
+            self.batched = True
             w0 = w0[0]
-        except TypeError:
-            seed = [seed]
-            self.batched = False
-        if self.cpu_tree:
-            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
         else:
-            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+            seed = (seed,)
+        if self.cpu_tree:
+            t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
+        self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
 
     @staticmethod
     def sort(a, b):

@@ -111,11 +111,10 @@ class BatchedBrownianTree:
     def __call__(self, t0, t1):
         t0, t1, sign = self.sort(t0, t1)
+        device, dtype = t0.device, t0.dtype
         if self.cpu_tree:
-            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
-        else:
-            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+            t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
+        w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
         return w if self.batched else w[0]
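
The rewritten constructor folds the old try/assert/except TypeError dance into explicit branches: seed may be None (draw one fresh seed), a scalar, or a list/tuple whose length must match the batch size, and w0 is now popped from kwargs rather than read with get, so a caller-supplied w0 is passed to torchsde.BrownianTree only once instead of both positionally and through **kwargs. A standalone sketch of the seed normalization, with illustrative names and no torch/torchsde dependency:

import random

# Standalone sketch (illustrative names) of the seed handling introduced above:
# seeds are normalized to a tuple, and a list/tuple of seeds must match the batch size.
def normalize_seed(seed, batch_size):
    batched = False
    if seed is None:
        seed = (random.randint(0, 2 ** 63 - 1),)  # one fresh seed, non-batched
    elif isinstance(seed, (tuple, list)):
        if len(seed) != batch_size:
            raise ValueError("a list or tuple of seeds must match the batch size")
        batched = True                            # one tree per batch element
    else:
        seed = (seed,)                            # single explicit seed, non-batched
    return tuple(seed), batched

print(normalize_seed(None, 4))        # ((<random>,), False)
print(normalize_seed([1, 2, 3], 3))   # ((1, 2, 3), True)
print(normalize_seed(42, 3))          # ((42,), False)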

View File

@@ -600,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             mask = mask.unsqueeze(1)
 
     try:
-        assert mask is None
+        if mask is not None:
+            raise RuntimeError("Mask must not be set for Flash attention")
         out = flash_attn_wrapper(
             q.transpose(1, 2),
             k.transpose(1, 2),
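
The assert-based guard sat inside a try block, so the intent was to reject masks before calling flash attention and, presumably, fall back to another attention path in the except branch that lies outside this excerpt; under python -O the assert disappears and a mask would be passed on silently. The explicit RuntimeError keeps that rejection in place. A pattern sketch with hypothetical stand-in functions:

# Pattern sketch only; flash_attention_stub and fallback_attention are hypothetical
# stand-ins, and the real except branch is outside the excerpt above.
def flash_attention_stub(q, k, v, mask=None):
    if mask is not None:
        # survives python -O, unlike "assert mask is None"
        raise RuntimeError("Mask must not be set for Flash attention")
    return q  # placeholder for the real flash-attention output

def fallback_attention(q, k, v, mask=None):
    return q  # placeholder for an attention implementation that supports masks

def attention_with_fallback(q, k, v, mask=None):
    try:
        return flash_attention_stub(q, k, v, mask=mask)
    except RuntimeError:
        return fallback_attention(q, k, v, mask=mask)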