diff --git a/comfy/ops.py b/comfy/ops.py
index ca25693db..b5cd1d47e 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -928,6 +928,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             weight = state_dict.pop(weight_key, None)
             if weight is None:
                 logging.warning(f"Missing weight for layer {layer_name}")
+                self.weight = None
                 return
             manually_loaded_keys = [weight_key]
@@ -1034,6 +1035,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             if self.bias is not None:
                 sd["{}bias".format(prefix)] = self.bias
+            if self.weight is None:
+                return sd
+
             if isinstance(self.weight, QuantizedTensor):
                 sd_out = self.weight.state_dict("{}weight".format(prefix))
                 for k in sd_out:
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index d93926648..06f2fbf74 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
-    lm_head: bool = False
+    lm_head: bool = True
     stop_tokens = [151643, 151645]
@@ -912,6 +912,9 @@ class BaseGenerate:
 class BaseQwen3:
     def logits(self, x):
         input = x[:, -1:]
+        if self.model.config.lm_head:
+            return self.model.lm_head(input)
+
         module = self.model.embed_tokens
         offload_stream = None