diff --git a/comfy/model_base.py b/comfy/model_base.py
index 7e86e76d0..5138d2b96 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -96,7 +96,7 @@ class BaseModel(torch.nn.Module):
 
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
+                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
                 operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
@@ -246,8 +246,8 @@ class BaseModel(torch.nn.Module):
 
         unet_state_dict = self.diffusion_model.state_dict()
 
-        if self.model_config.scaled_fp8:
-            unet_state_dict["scaled_fp8"] = torch.tensor([])
+        if self.model_config.scaled_fp8 is not None:
+            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
 
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 3f720bce0..e1d29db3d 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -288,8 +288,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     if model_config is None and use_base_if_no_match:
         model_config = comfy.supported_models_base.BASE(unet_config)
 
-    if "{}scaled_fp8".format(unet_key_prefix) in state_dict:
-        model_config.scaled_fp8 = True
+    scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
+    if scaled_fp8_weight is not None:
+        model_config.scaled_fp8 = scaled_fp8_weight.dtype
+        if model_config.scaled_fp8 == torch.float32:
+            model_config.scaled_fp8 = torch.float8_e4m3fn
 
     return model_config
 
diff --git a/comfy/ops.py b/comfy/ops.py
index 05f7d3064..5e7c668eb 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -250,6 +250,12 @@ def fp8_linear(self, input):
     if dtype not in [torch.float8_e4m3fn]:
         return None
 
+    tensor_2d = False
+    if len(input.shape) == 2:
+        tensor_2d = True
+        input = input.unsqueeze(1)
+
+
     if len(input.shape) == 3:
         w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
         w = w.t()
@@ -272,7 +278,11 @@ def fp8_linear(self, input):
         if isinstance(o, tuple):
             o = o[0]
 
+        if tensor_2d:
+            return o.reshape(input.shape[0], -1)
+
         return o.reshape((-1, input.shape[1], self.weight.shape[0]))
+
     return None
 
 class fp8_ops(manual_cast):
@@ -316,7 +326,11 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
                         return out
 
                 weight, bias = cast_bias_weight(self, input)
-                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+
+                if weight.numel() < input.numel(): #TODO: optimize
+                    return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                else:
+                    return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
 
             def convert_weight(self, weight, inplace=False, **kwargs):
                 if inplace:
@@ -334,10 +348,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
 
     return scaled_fp8_op
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
-    if scaled_fp8:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True)
+    if scaled_fp8 is not None:
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
 
     if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
         return fp8_ops
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 1ecb41dda..f85bd203a 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -358,8 +358,11 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
     ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
 
     sigs = []
+    last_t = -1
     for t in ts:
-        sigs += [float(model_sampling.sigmas[int(t)])]
+        if t != last_t:
+            sigs += [float(model_sampling.sigmas[int(t)])]
+        last_t = t
     sigs += [0.0]
     return torch.FloatTensor(sigs)
 
diff --git a/comfy/sd.py b/comfy/sd.py
index bcec48c03..e4abf0b94 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -579,7 +579,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         return None
 
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if weight_dtype is not None:
+    if weight_dtype is not None and model_config.scaled_fp8 is None:
         unet_weight_dtype.append(weight_dtype)
 
     model_config.custom_operations = model_options.get("custom_operations", None)
@@ -677,7 +677,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
     offload_device = model_management.unet_offload_device()
 
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if weight_dtype is not None:
+    if weight_dtype is not None and model_config.scaled_fp8 is None:
        unet_weight_dtype.append(weight_dtype)
 
     if dtype is None:
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 0e69d16a0..54573abb1 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -49,7 +49,7 @@ class BASE:
 
     manual_cast_dtype = None
    custom_operations = None
-    scaled_fp8 = False
+    scaled_fp8 = None
     optimizations = {"fp8": False}
 
     @classmethod
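
Notes on the changes above (illustrative sketches, not part of the patch):

The core change turns `scaled_fp8` from a boolean into an optional dtype. On save, the `scaled_fp8` key stores an empty tensor whose dtype encodes the fp8 format; on load, `model_config_from_unet` reads the dtype back, treating a legacy `torch.float32` marker (what the old `torch.tensor([])` produced) as `torch.float8_e4m3fn`. A minimal sketch of the round-trip, with a hypothetical harness around the logic from the diff:

```python
import torch

def save_marker(scaled_fp8_dtype):
    # saving: an empty tensor carries the dtype, no weight data
    return {"scaled_fp8": torch.tensor([], dtype=scaled_fp8_dtype)}

def detect_marker(state_dict, prefix=""):
    # loading: recover the dtype from the marker; old float32 markers
    # fall back to float8_e4m3fn for backward compatibility
    w = state_dict.get("{}scaled_fp8".format(prefix), None)
    if w is None:
        return None
    dtype = w.dtype
    if dtype == torch.float32:
        dtype = torch.float8_e4m3fn
    return dtype

sd = save_marker(torch.float8_e5m2)
assert detect_marker(sd) == torch.float8_e5m2
# legacy checkpoint: bare torch.tensor([]) defaults to float32
assert detect_marker({"scaled_fp8": torch.tensor([])}) == torch.float8_e4m3fn
```

This is also why `sd.py` stops appending `weight_dtype` when `scaled_fp8` is set: the marker's dtype, not the stored weight dtype, now decides the operations.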
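The `fp8_linear` change lets 2D activations reuse the existing 3D path: a `(N, F)` input is unsqueezed to `(N, 1, F)`, run through the fp8 matmul, and folded back to `(N, out)`. A shape-only sketch, with a plain matmul standing in for the `torch._scaled_mm` call in the real function:

```python
import torch

def linear_3d_path(x3, w):
    # stand-in for the fp8 matmul path that assumes (batch, tokens, features)
    o = x3.reshape(-1, x3.shape[-1]) @ w.t()
    return o.reshape((-1, x3.shape[1], w.shape[0]))

def fp8_like_linear(x, w):
    tensor_2d = len(x.shape) == 2
    if tensor_2d:
        x = x.unsqueeze(1)                # (N, F) -> (N, 1, F)
    o = linear_3d_path(x, w)
    if tensor_2d:
        return o.reshape(x.shape[0], -1)  # fold back to (N, out)
    return o

w = torch.randn(256, 64)
assert fp8_like_linear(torch.randn(8, 64), w).shape == (8, 256)
assert fp8_like_linear(torch.randn(2, 8, 64), w).shape == (2, 8, 256)
```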
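In `scaled_fp8_ops`, the new `weight.numel() < input.numel()` branch exploits the fact that a scalar dequantization scale commutes through the linear map: `linear(x, s*W, b) == linear(s*x, W, b)`, so the multiply can be applied to whichever tensor is smaller. A quick check of that identity (assuming `scale_weight` is a scalar, as in this path):

```python
import torch

x = torch.randn(4, 128, 64)   # batch of activations
w = torch.randn(256, 64)      # linear weight (out_features, in_features)
b = torch.randn(256)
s = torch.tensor(0.03125)     # scalar dequant scale

out_w = torch.nn.functional.linear(x, w * s, b)  # scale the weight
out_x = torch.nn.functional.linear(x * s, w, b)  # scale the input
assert torch.allclose(out_w, out_x, atol=1e-4)

# scaling the smaller tensor does fewer multiplies:
# weight has 256*64 = 16384 elements vs the input's 4*128*64 = 32768
```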
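The `beta_scheduler` fix deduplicates consecutive timesteps: when `steps` is large relative to the sigma table, `numpy.rint` can round neighbouring steps to the same index, which previously emitted repeated sigmas (zero-length steps). A small repro of the patched loop, assuming the upstream `ts` construction from `numpy.linspace`:

```python
import numpy
import scipy.stats
import torch

steps, alpha, beta = 30, 0.6, 0.6
sigmas = torch.linspace(0.03, 14.6, 21)  # hypothetical short sigma table
total_timesteps = len(sigmas) - 1

ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)

sigs = []
last_t = -1
for t in ts:
    if t != last_t:  # skip steps that rounded onto the previous index
        sigs += [float(sigmas[int(t)])]
    last_t = t
sigs += [0.0]

# 30 steps into 21 indices must collide, so the schedule shrinks
print(len(ts) + 1, "->", len(sigs))
```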