Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-10-22 13:39:15 +03:00 committed by GitHub
commit f0e8767deb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 33 additions and 13 deletions

View File

@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
@ -246,8 +246,8 @@ class BaseModel(torch.nn.Module):
unet_state_dict = self.diffusion_model.state_dict()
if self.model_config.scaled_fp8:
unet_state_dict["scaled_fp8"] = torch.tensor([])
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

View File

@ -288,8 +288,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)
if "{}scaled_fp8".format(unet_key_prefix) in state_dict:
model_config.scaled_fp8 = True
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
if scaled_fp8_weight is not None:
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
return model_config

View File

@ -250,6 +250,12 @@ def fp8_linear(self, input):
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
@ -272,7 +278,11 @@ def fp8_linear(self, input):
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input.shape[0], -1)
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return None
class fp8_ops(manual_cast):
@ -316,7 +326,11 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
@ -334,10 +348,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
return scaled_fp8_op
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops

View File

@ -358,8 +358,11 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
sigs = []
last_t = -1
for t in ts:
sigs += [float(model_sampling.sigmas[int(t)])]
if t != last_t:
sigs += [float(model_sampling.sigmas[int(t)])]
last_t = t
sigs += [0.0]
return torch.FloatTensor(sigs)

View File

@ -579,7 +579,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
@ -677,7 +677,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
if dtype is None:

View File

@ -49,7 +49,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = False
scaled_fp8 = None
optimizations = {"fp8": False}
@classmethod