Lens: some cleanup (#14112)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

* Lens: remove redundant memory optimization
This commit is contained in:
Jukka Seppänen 2026-05-26 10:32:53 +03:00 committed by GitHub
parent 41812fa0ac
commit f9f54cae42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -141,7 +141,6 @@ class LensJointAttention(nn.Module):
img_q, img_k, img_v = img_qkv.unbind(dim=2) img_q, img_k, img_v = img_qkv.unbind(dim=2)
img_q = self.norm_q(img_q) img_q = self.norm_q(img_q)
img_k = self.norm_k(img_k) img_k = self.norm_k(img_k)
img_v = img_v.contiguous()
del img_qkv del img_qkv
# text stream # text stream
@ -149,8 +148,6 @@ class LensJointAttention(nn.Module):
txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2) txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
txt_q = self.norm_added_q(txt_q) txt_q = self.norm_added_q(txt_q)
txt_k = self.norm_added_k(txt_k) txt_k = self.norm_added_k(txt_k)
txt_v = txt_v.contiguous()
del txt_qkv
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2) q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)