diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d93926648..06f2fbf74 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -224,7 +224,7 @@ class Qwen3_8BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True - lm_head: bool = False + lm_head: bool = True stop_tokens = [151643, 151645] @dataclass @@ -912,6 +912,9 @@ class BaseGenerate: class BaseQwen3: def logits(self, x): input = x[:, -1:] + if self.model.config.lm_head: + return self.model.lm_head(input) + module = self.model.embed_tokens offload_stream = None