llama: cast logits as a comfy-weight (#12248)

The logits projection was using a different layer's weight via .to(). Change it to
use the ops caster when the original layer is a comfy weight, so that it picks
up the dynamic_vram and async_offload functionality in full.
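For illustration, here is a minimal sketch of the pattern this change adopts, generalized to any linear-style projection. The comfy_cast_weights flag and the cast_bias_weight/uncast_bias_weight pair are taken from the diff below; the helper name linear_with_comfy_cast and the getattr fallback are illustrative, not part of the commit.

    import torch
    import comfy.ops

    def linear_with_comfy_cast(module, x):
        # Hypothetical helper showing the pattern from this commit: when the
        # module is a comfy-managed weight, cast it through comfy.ops so the
        # cast participates in dynamic_vram and async_offload; otherwise fall
        # back to the old plain .to() device/dtype cast.
        offload_stream = None
        if getattr(module, "comfy_cast_weights", False):
            # offloadable=True also returns the stream used for the async copy.
            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, x, offloadable=True)
        else:
            weight = module.weight.to(x)
        out = torch.nn.functional.linear(x, weight, None)
        # Hand the casted weight back so it can be offloaded again on its stream.
        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
        return out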

Co-authored-by: Rattus <rattus128@gmail.com>
comfyanonymous 2026-02-03 08:31:36 -08:00 committed by GitHub
parent affe881354
commit 223364743c


@@ -6,6 +6,7 @@ import math
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
+import comfy.ops
 import comfy.ldm.common_dit
 import comfy.clip_model
@@ -794,7 +795,19 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
         self.dtype = dtype

     def logits(self, x):
-        return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None)
+        input = x[:, -1:]
+        module = self.model.embed_tokens
+        offload_stream = None
+        if module.comfy_cast_weights:
+            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
+        else:
+            weight = self.model.embed_tokens.weight.to(x)
+        x = torch.nn.functional.linear(input, weight, None)
+        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
+        return x

 class Qwen3_4B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
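For anyone wanting to poke at the caster path on its own, a minimal sketch follows. The cast_bias_weight/uncast_bias_weight signatures are taken from this commit; using comfy.ops.manual_cast.Linear as a stand-in comfy-weight layer (it sets comfy_cast_weights) is an assumption of this sketch, not something the commit touches.

    import torch
    import comfy.ops

    # manual_cast layers mark themselves with comfy_cast_weights, so a module
    # like this would take the caster branch of the new logits() above.
    # Weights are left uninitialized because these ops skip reset_parameters.
    layer = comfy.ops.manual_cast.Linear(2048, 2048, dtype=torch.float16, device="cpu")
    x = torch.randn(1, 1, 2048, dtype=torch.float16)

    # Cast weight/bias to x's dtype and device; offloadable=True also returns
    # the stream used for any async transfer (None when nothing was offloaded).
    weight, bias, offload_stream = comfy.ops.cast_bias_weight(layer, x, offloadable=True)
    out = torch.nn.functional.linear(x, weight, bias)
    # Hand the casted tensors back so async_offload can reclaim them.
    comfy.ops.uncast_bias_weight(layer, weight, bias, offload_stream)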