Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-22 04:10:15 +08:00)
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
febf8601dc
@@ -178,7 +178,7 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||||
|
|
||||||
if txt.dtype == torch.float16:
|
if txt.dtype == torch.float16:
|
||||||
txt = txt.clip(-65504, 65504)
|
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
|
|
||||||
return img, txt
|
return img, txt
|
||||||
|
|
||||||
@@ -233,7 +233,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
x += mod.gate * output
|
x += mod.gate * output
|
||||||
if x.dtype == torch.float16:
|
if x.dtype == torch.float16:
|
||||||
x = x.clip(-65504, 65504)
|
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -682,6 +682,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
|||||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
|
||||||
for dt in supported_dtypes:
|
for dt in supported_dtypes:
|
||||||
if dt == torch.float16 and fp16_supported:
|
if dt == torch.float16 and fp16_supported:
|
||||||
return torch.float16
|
return torch.float16
|
||||||
|
|||||||
21
comfy/ops.py
21
comfy/ops.py
@@ -254,16 +254,33 @@ def fp8_linear(self, input):
|
|||||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||||
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
|
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
|
||||||
|
|
||||||
|
scale_weight = self.scale_weight
|
||||||
|
scale_input = self.scale_input
|
||||||
|
if scale_weight is None:
|
||||||
|
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||||
|
if scale_input is None:
|
||||||
|
scale_input = scale_weight
|
||||||
|
if scale_input is None:
|
||||||
|
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
|
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight)
|
||||||
else:
|
else:
|
||||||
o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype)
|
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||||
|
|
||||||
|
if isinstance(o, tuple):
|
||||||
|
o = o[0]
|
||||||
|
|
||||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class fp8_ops(manual_cast):
|
class fp8_ops(manual_cast):
|
||||||
class Linear(manual_cast.Linear):
|
class Linear(manual_cast.Linear):
|
||||||
|
def reset_parameters(self):
|
||||||
|
self.scale_weight = None
|
||||||
|
self.scale_input = None
|
||||||
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
out = fp8_linear(self, input)
|
out = fp8_linear(self, input)
|
||||||
if out is not None:
|
if out is not None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user