From 223364743c35d6e1dc4f5ebc3796234d0f8484cf Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 3 Feb 2026 08:31:36 -0800
Subject: [PATCH] llama: cast logits as a comfy-weight (#12248)

This is using a different layer's weight with .to(). Change it to use the
ops caster if the original layer is a comfy weight, so that it picks up
dynamic_vram and async_offload functionality in full.

Co-authored-by: Rattus
---
 comfy/text_encoders/llama.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index d2324ffc5..d1c628d20 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -6,6 +6,7 @@ import math
 
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
+import comfy.ops
 import comfy.ldm.common_dit
 import comfy.clip_model
 
@@ -794,7 +795,19 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
         self.dtype = dtype
 
     def logits(self, x):
-        return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None)
+        input = x[:, -1:]
+        module = self.model.embed_tokens
+
+        offload_stream = None
+        if module.comfy_cast_weights:
+            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
+        else:
+            weight = self.model.embed_tokens.weight.to(x)
+
+        x = torch.nn.functional.linear(input, weight, None)
+
+        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
+        return x
 
 class Qwen3_4B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
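
For reference, the pattern this patch introduces generalizes to any linear-style
use of a comfy-ops module's weight. The sketch below is illustrative only and is
not part of the patch: the helper name cast_linear is hypothetical, and it
assumes (as the diff above implies) that comfy.ops.cast_bias_weight(module,
input, offloadable=True) returns a (weight, bias, offload_stream) triple and
that comfy.ops.uncast_bias_weight may be called even when the caster branch was
not taken.

import torch

import comfy.ops


def cast_linear(module, x):
    # Hypothetical helper mirroring the patched logits() body: route the
    # weight through the comfy ops caster when the module is a comfy weight,
    # so dynamic_vram and async_offload apply; otherwise fall back to .to().
    offload_stream = None
    if module.comfy_cast_weights:
        # Assumed to return (weight, bias, offload_stream) when
        # offloadable=True, as in the diff above.
        weight, _, offload_stream = comfy.ops.cast_bias_weight(module, x, offloadable=True)
    else:
        weight = module.weight.to(x)

    out = torch.nn.functional.linear(x, weight, None)

    # Mirrors the patch: uncast unconditionally; offload_stream stays None
    # when the caster branch was not used.
    comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
    return out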