Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix depending on asserts to raise an exception in BatchedBrownianTree and Flash attn module (#9884)
Correctly handle the case where w0 is passed via kwargs in BatchedBrownianTree.
This commit is contained in: commit 1a85483da1 (parent 47a9cde5d3)
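
Background for both changes, with a small illustrative snippet (not part of the commit): Python compiles assert statements away entirely when run with the -O flag, so control flow that depends on an assert raising (the except TypeError branch in BatchedBrownianTree and the except fallback around the flash attention call) silently stops working in optimized runs.

    # guard_demo.py -- illustrative only, not from the commit.
    # "python guard_demo.py"    prints: fallback path
    # "python -O guard_demo.py" prints: flash path (the assert is compiled out)
    def guarded(mask):
        try:
            assert mask is None
            return "flash path"
        except AssertionError:
            return "fallback path"

    print(guarded(object()))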
@@ -86,24 +86,24 @@ class BatchedBrownianTree:
     """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
 
     def __init__(self, x, t0, t1, seed=None, **kwargs):
-        self.cpu_tree = True
-        if "cpu" in kwargs:
-            self.cpu_tree = kwargs.pop("cpu")
+        self.cpu_tree = kwargs.pop("cpu", True)
         t0, t1, self.sign = self.sort(t0, t1)
-        w0 = kwargs.get('w0', torch.zeros_like(x))
-        if seed is None:
-            seed = torch.randint(0, 2 ** 63 - 1, []).item()
-        self.batched = True
-        try:
-            assert len(seed) == x.shape[0]
-            w0 = w0[0]
-        except TypeError:
-            seed = [seed]
-            self.batched = False
-        if self.cpu_tree:
-            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
-        else:
-            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+        w0 = kwargs.pop('w0', None)
+        if w0 is None:
+            w0 = torch.zeros_like(x)
+        self.batched = False
+        if seed is None:
+            seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
+        elif isinstance(seed, (tuple, list)):
+            if len(seed) != x.shape[0]:
+                raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
+            self.batched = True
+            w0 = w0[0]
+        else:
+            seed = (seed,)
+        if self.cpu_tree:
+            t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
+        self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
 
     @staticmethod
     def sort(a, b):
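
A short usage sketch of the new behavior (illustrative; assumes torch and torchsde are installed, with the import path as laid out in the ComfyUI tree): None or a scalar seed builds a single unbatched tree, a list or tuple of seeds must match the batch dimension, and w0 is now popped from kwargs before they are forwarded to torchsde.BrownianTree, so passing w0 as a keyword no longer collides with the positional w0 argument.

    import torch
    # BatchedBrownianTree as patched above; it lives in comfy/k_diffusion/sampling.py
    # in the ComfyUI tree (the import path may differ in other setups).
    from comfy.k_diffusion.sampling import BatchedBrownianTree

    x = torch.zeros(2, 4)                          # batch of 2
    t0, t1 = torch.tensor(0.0), torch.tensor(1.0)

    BatchedBrownianTree(x, t0, t1)                 # None seed: one fresh scalar seed
    BatchedBrownianTree(x, t0, t1, seed=42)        # scalar seed: unbatched
    BatchedBrownianTree(x, t0, t1, seed=[1, 2])    # one tree per batch item (batched)
    BatchedBrownianTree(x, t0, t1, w0=torch.zeros_like(x))  # w0 kwarg no longer collides
    BatchedBrownianTree(x, t0, t1, seed=[1])       # raises ValueError: length mismatch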
@@ -111,11 +111,10 @@ class BatchedBrownianTree:
 
     def __call__(self, t0, t1):
         t0, t1, sign = self.sort(t0, t1)
+        device, dtype = t0.device, t0.dtype
         if self.cpu_tree:
-            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
-        else:
-            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
-
+            t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
+        w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
         return w if self.batched else w[0]
 
 
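
Continuing the sketch: __call__ now captures the caller's device and dtype up front, samples every tree (moving the endpoints to CPU only when cpu_tree is set), and restores device and dtype with a single .to() on the stacked result, so both branches share one stack expression.

    # Brownian increment between two times; the result comes back on the
    # device/dtype of the t0 passed to __call__, even for CPU-backed trees.
    tree = BatchedBrownianTree(x, t0, t1, seed=[1, 2])
    dw = tree(torch.tensor(0.2), torch.tensor(0.7))
    assert dw.shape == x.shape                     # batched: one increment per item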
@@ -600,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
         mask = mask.unsqueeze(1)
 
     try:
-        assert mask is None
+        if mask is not None:
+            raise RuntimeError("Mask must not be set for Flash attention")
         out = flash_attn_wrapper(
             q.transpose(1, 2),
             k.transpose(1, 2),
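
Per the hunk header this guard sits inside attention_flash, whose try/except lets a failed flash call fall back to another kernel. A minimal sketch of why the explicit raise matters; the function name, layout handling, and SDPA fallback here are illustrative assumptions, not the module's exact code:

    import torch
    import torch.nn.functional as F

    def flash_with_fallback(q, k, v, mask=None):
        # Sketch only: the real module also reshapes q/k/v for flash-attn's layout.
        try:
            if mask is not None:
                # Survives "python -O"; the old "assert mask is None" would be
                # stripped, silently running the mask-less kernel on masked input.
                raise RuntimeError("Mask must not be set for Flash attention")
            from flash_attn import flash_attn_func  # flash-attn package entry point
            out = flash_attn_func(q, k, v)
        except Exception:
            # Any failure (mask present, flash-attn missing, kernel error)
            # falls back to PyTorch scaled dot-product attention.
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out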