mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 10:02:59 +08:00
Compare commits
845 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ffdd53b327 | ||
|
|
65e2103b09 | ||
|
|
9304e47351 | ||
|
|
bc606d7d64 | ||
|
|
645ee1881e | ||
|
|
3d082c3206 | ||
|
|
683569de55 | ||
|
|
ea2c117bc3 | ||
|
|
fc4af86068 | ||
|
|
41bcf0619d | ||
|
|
d02d0e5744 | ||
|
|
70541d4e77 | ||
|
|
77b2f7c228 | ||
|
|
43e0d4e3cc | ||
|
|
dbd330454a | ||
|
|
33c7f1179d | ||
|
|
af91eb6c99 | ||
|
|
5cb1e0c9a0 | ||
|
|
51347f9fb8 | ||
|
|
a5e85017d8 | ||
|
|
5ac3b26a7d | ||
|
|
6592bffc60 | ||
|
|
971cefe7d4 | ||
|
|
da2bfb5b0a | ||
|
|
c5a47a1692 | ||
|
|
908fd7d749 | ||
|
|
5495589db3 | ||
|
|
982876d59a | ||
|
|
338d9ae3bb | ||
|
|
eeb020b9b7 | ||
|
|
ae65433a60 | ||
|
|
fdebe18296 | ||
|
|
f8321eb57b | ||
|
|
93948e3fc5 | ||
|
|
e711aaf1a7 | ||
|
|
57ddb7fd13 | ||
|
|
17c92a9f28 | ||
|
|
36357bbcc3 | ||
|
|
f668c2e3c9 | ||
|
|
fc657f471a | ||
|
|
791e30ff50 | ||
|
|
e2a800e7ef | ||
|
|
9d252f3b70 | ||
|
|
b9fb542703 | ||
|
|
cabc4d351f | ||
|
|
e136b6dbb0 | ||
|
|
d50f342c90 | ||
|
|
3b0368aa34 | ||
|
|
935493f6c1 | ||
|
|
60ee574748 | ||
|
|
8e889c535d | ||
|
|
fd271dedfd | ||
|
|
c3c6313fc7 | ||
|
|
85c4b4ae26 | ||
|
|
058f084371 | ||
|
|
ec7f65187d | ||
|
|
56fa7dbe38 | ||
|
|
329480da5a | ||
|
|
4086acf3c2 | ||
|
|
50ca97e776 | ||
|
|
7ac7d69d94 | ||
|
|
76f18e955d | ||
|
|
d7a0aef650 | ||
|
|
913f86b727 | ||
|
|
117bf3f2bd | ||
|
|
ae676ed105 | ||
|
|
fd109325db | ||
|
|
bed12674a1 | ||
|
|
092ee8a500 | ||
|
|
79d17ba233 | ||
|
|
6fd463aec9 | ||
|
|
43071e3de3 | ||
|
|
0ec05b1481 | ||
|
|
35fa091340 | ||
|
|
3c8456223c | ||
|
|
9bc893c5bb | ||
|
|
f4bdf5f830 | ||
|
|
6be85c7920 | ||
|
|
ea17add3c6 | ||
|
|
ecdc8697d5 | ||
|
|
dce518c2b4 | ||
|
|
440268d394 | ||
|
|
87c104bfc1 | ||
|
|
19f2192d69 | ||
|
|
519c941165 | ||
|
|
861817d22d | ||
|
|
c120eee5ba | ||
|
|
73f5649196 | ||
|
|
3f512f5659 | ||
|
|
b94d394a64 | ||
|
|
277237ccc1 | ||
|
|
daaceac769 | ||
|
|
33d6aec3b7 | ||
|
|
44baa0b7f3 | ||
|
|
a17cf1c387 | ||
|
|
b4a20acc54 | ||
|
|
c55dc857d5 | ||
|
|
878db3a727 | ||
|
|
30c259cac8 | ||
|
|
1cb7e22a95 | ||
|
|
2640acb31c | ||
|
|
7dbd5dfe91 | ||
|
|
f8b981ae9a | ||
|
|
4967f81778 | ||
|
|
0a6746898d | ||
|
|
5151cff293 | ||
|
|
af96d9812d | ||
|
|
52a32e2b32 | ||
|
|
b907085709 | ||
|
|
065a2fbbec | ||
|
|
0ff0457892 | ||
|
|
6484ac89dc | ||
|
|
f55c98a89f | ||
|
|
ca7808f240 | ||
|
|
52e778fff3 | ||
|
|
9d8a817985 | ||
|
|
b59750a86a | ||
|
|
3f382a4f98 | ||
|
|
f17251bec6 | ||
|
|
c38e7d6599 | ||
|
|
eaf68c9b5b | ||
|
|
cc6a8dcd1a | ||
|
|
a2d60aad0f | ||
|
|
d8433c63fd | ||
|
|
dd41b74549 | ||
|
|
55f654db3d | ||
|
|
58c6ed541d | ||
|
|
234c3dc85f | ||
|
|
8908ee2628 | ||
|
|
1105e0d139 | ||
|
|
8938aa3f30 | ||
|
|
f16219e3aa | ||
|
|
8402c8700a | ||
|
|
58b8574661 | ||
|
|
90b3995ec8 | ||
|
|
bdb10a583f | ||
|
|
0e24dbb19f | ||
|
|
e9aae31fa2 | ||
|
|
0c18842acb | ||
|
|
d196a905bb | ||
|
|
18b79acba9 | ||
|
|
dff996ca39 | ||
|
|
828b1b9953 | ||
|
|
af81cb962d | ||
|
|
5c7b08ca58 | ||
|
|
6b573ae0cb | ||
|
|
015a0599d0 | ||
|
|
acfaa5c4a1 | ||
|
|
b6805429b9 | ||
|
|
25022e0b09 | ||
|
|
22a2644e57 | ||
|
|
b2ef58e2b1 | ||
|
|
6a6d456c88 | ||
|
|
3d1fdaf9f4 | ||
|
|
1286fcfe40 | ||
|
|
3bd71554a2 | ||
|
|
f66183a541 | ||
|
|
cbd68e3d58 | ||
|
|
d89c29f259 | ||
|
|
a9c35256bc | ||
|
|
532938b16b | ||
|
|
ecb683b057 | ||
|
|
c55fd74816 | ||
|
|
3398123752 | ||
|
|
943b3b615d | ||
|
|
10e90a5757 | ||
|
|
b75d349f25 | ||
|
|
7b8389578e | ||
|
|
9e00ce5b76 | ||
|
|
f5e66d5e47 | ||
|
|
87b0359392 | ||
|
|
cb96d4d18c | ||
|
|
394348f5ca | ||
|
|
7601e89255 | ||
|
|
6a1d3a1ae1 | ||
|
|
65ee24c978 | ||
|
|
17027f2a6a | ||
|
|
b5c8be8b1d | ||
|
|
24fdb92edf | ||
|
|
d526974576 | ||
|
|
e1ab6bb394 | ||
|
|
048f49adbd | ||
|
|
47bfd5a33f | ||
|
|
fdf49a2861 | ||
|
|
f41e5f398d | ||
|
|
27cbac865e | ||
|
|
3d0003c24c | ||
|
|
7d6103325e | ||
|
|
2d4a08b717 | ||
|
|
9a02382568 | ||
|
|
bd01d9f7fd | ||
|
|
443056c401 | ||
|
|
f60923590c | ||
|
|
1ef328c007 | ||
|
|
94c298f962 | ||
|
|
2fde9597f4 | ||
|
|
f91078b1ff | ||
|
|
3b3ef9a77a | ||
|
|
8b0b93df51 | ||
|
|
1c7eaeca10 | ||
|
|
18e7d6dba5 | ||
|
|
e1d85e7577 | ||
|
|
1199411747 | ||
|
|
5ebcab3c7d | ||
|
|
c350009236 | ||
|
|
dea899f221 | ||
|
|
e632e5de28 | ||
|
|
2abd2b5c20 | ||
|
|
a1a70362ca | ||
|
|
cf97b033ee | ||
|
|
eb1c42f649 | ||
|
|
e05c907126 | ||
|
|
09dc24c8a9 | ||
|
|
1d69245981 | ||
|
|
97f198e421 | ||
|
|
bda0eb2448 | ||
|
|
c4a6b389de | ||
|
|
4cd881866b | ||
|
|
265adad858 | ||
|
|
7f3e4d486c | ||
|
|
a389ee01bb | ||
|
|
9c71a66790 | ||
|
|
af4b7b5edb | ||
|
|
0f4ef3afa0 | ||
|
|
6b88478f9f | ||
|
|
e199c8cc67 | ||
|
|
0652cb8e2d | ||
|
|
958a17199a | ||
|
|
e974e554ca | ||
|
|
4e2110c794 | ||
|
|
e617cddf24 | ||
|
|
1f3f7a2823 | ||
|
|
88df172790 | ||
|
|
6d6a18b0b7 | ||
|
|
97ff9fae7e | ||
|
|
135fa49ec2 | ||
|
|
44869ff786 | ||
|
|
20182a393f | ||
|
|
5f109fe6a0 | ||
|
|
c58c13b2ba | ||
|
|
7f374e42c8 | ||
|
|
27d1bd8829 | ||
|
|
614cf9805e | ||
|
|
513b0c46fb | ||
|
|
dfac94695b | ||
|
|
163b629c70 | ||
|
|
998bf60beb | ||
|
|
906c089957 | ||
|
|
25de7b1bfa | ||
|
|
ab7ab5be23 | ||
|
|
ec4fc2a09a | ||
|
|
1a58087ac2 | ||
|
|
6c14f3afac | ||
|
|
e525673f72 | ||
|
|
3fa7a5c04a | ||
|
|
210f7a1ba5 | ||
|
|
d202c2ba74 | ||
|
|
8817f8fc14 | ||
|
|
22e40d2ace | ||
|
|
3bea4efc6b | ||
|
|
8cf2ba4ba6 | ||
|
|
b61a40cbc9 | ||
|
|
f2bb3230b7 | ||
|
|
614b8d3345 | ||
|
|
6abc30aae9 | ||
|
|
55bad30375 | ||
|
|
c305deed56 | ||
|
|
601ee1775a | ||
|
|
c170fd2db5 | ||
|
|
9d529e5308 | ||
|
|
f6bbc1ac84 | ||
|
|
098a352f13 | ||
|
|
e86b79ab9e | ||
|
|
426cde37f1 | ||
|
|
dd5af0c587 | ||
|
|
388b306a2b | ||
|
|
24188b3141 | ||
|
|
1bcda6df98 | ||
|
|
a1864c01f2 | ||
|
|
4739d7717f | ||
|
|
f13cff0be6 | ||
|
|
9cdc64998f | ||
|
|
560b1bdfca | ||
|
|
b7992f871a | ||
|
|
2c2aa409b0 | ||
|
|
a4787ac83b | ||
|
|
b5c59b763c | ||
|
|
b4f30bd408 | ||
|
|
dad076aee6 | ||
|
|
0cf33953a7 | ||
|
|
5b80addafd | ||
|
|
9da397ea2f | ||
|
|
92d97380bd | ||
|
|
99ce2a1f66 | ||
|
|
b1467da480 | ||
|
|
d8d60b5609 | ||
|
|
b1293d50ef | ||
|
|
19b466160c | ||
|
|
bc0ad9bb49 | ||
|
|
4054b4bf38 | ||
|
|
55ac7d333c | ||
|
|
afa8a24fe1 | ||
|
|
493b81e48f | ||
|
|
6b035bfce2 | ||
|
|
74b7f0b04b | ||
|
|
f72c6616b2 | ||
|
|
1c10b33f9b | ||
|
|
ddfce1af4f | ||
|
|
7a883849ea | ||
|
|
84867067ea | ||
|
|
3374e900d0 | ||
|
|
51696e3fdc | ||
|
|
dfff7e5332 | ||
|
|
e4ea393666 | ||
|
|
c8674bc6e9 | ||
|
|
3dfdcf66b6 | ||
|
|
95ca2e56c8 | ||
|
|
27ffd12c45 | ||
|
|
e693e4db6a | ||
|
|
d68ece7301 | ||
|
|
894837de9a | ||
|
|
fdc92863b6 | ||
|
|
a125cd84b0 | ||
|
|
84e9ce32c6 | ||
|
|
f43b8ab2a2 | ||
|
|
14d642acd6 | ||
|
|
aa895db7e8 | ||
|
|
cdfc25a160 | ||
|
|
81e4dac107 | ||
|
|
90853fb9cd | ||
|
|
f1dd6e50f8 | ||
|
|
fc0fbf141c | ||
|
|
f3d5d328a3 | ||
|
|
139addd53c | ||
|
|
cbee7d3390 | ||
|
|
6732014a0a | ||
|
|
989f715d92 | ||
|
|
2ba8d7cce8 | ||
|
|
51fb505ffa | ||
|
|
72c2071972 | ||
|
|
6e59934089 | ||
|
|
3e0eb8d33f | ||
|
|
637221995f | ||
|
|
51697d50dc | ||
|
|
19f595b788 | ||
|
|
8a15568f10 | ||
|
|
9e984c48bc | ||
|
|
fc34c3d112 | ||
|
|
8aea746212 | ||
|
|
8c19910427 | ||
|
|
e77e0a8f8f | ||
|
|
a49007a7b0 | ||
|
|
6ae3515801 | ||
|
|
6bd3f8eb9f | ||
|
|
7326e46dee | ||
|
|
195e0b0639 | ||
|
|
187f43696d | ||
|
|
caf07331ff | ||
|
|
b1fa1922df | ||
|
|
2ed74f7ac7 | ||
|
|
22f99fb97e | ||
|
|
bbd683098e | ||
|
|
08726b64fe | ||
|
|
93d859cfaa | ||
|
|
4614ee09ca | ||
|
|
5c8e986e27 | ||
|
|
8c26d7bbe6 | ||
|
|
d7aa414141 | ||
|
|
3e68bc342c | ||
|
|
c2c5a7d5f8 | ||
|
|
8a293372ec | ||
|
|
ed3ca78e08 | ||
|
|
4ffea0e864 | ||
|
|
1395bce9f7 | ||
|
|
e9364ee279 | ||
|
|
f6e3e9a456 | ||
|
|
8f4ee9984c | ||
|
|
0e9d1724be | ||
|
|
4965c0e2ac | ||
|
|
911331c06c | ||
|
|
bb32d4ec31 | ||
|
|
a6f83a4a1a | ||
|
|
e4f99b479a | ||
|
|
d9c0a4053d | ||
|
|
11bab7be76 | ||
|
|
3af1881455 | ||
|
|
e0210ce0a7 | ||
|
|
7eb7160db4 | ||
|
|
638097829d | ||
|
|
c4a8cf60ab | ||
|
|
bab8ba20bf | ||
|
|
b682a73c55 | ||
|
|
631b9ae861 | ||
|
|
f48d7230de | ||
|
|
6e079abc3a | ||
|
|
977a4ed8c5 | ||
|
|
414a178fb6 | ||
|
|
447884b657 | ||
|
|
bed4b49d08 | ||
|
|
342cf644ce | ||
|
|
3758848423 | ||
|
|
0db6aabed3 | ||
|
|
1673ace19b | ||
|
|
7f38e4c538 | ||
|
|
8accf50908 | ||
|
|
ed0f4a609b | ||
|
|
041b8824f5 | ||
|
|
b1111c2062 | ||
|
|
05a258efd8 | ||
|
|
c8276f8c6b | ||
|
|
6ec1cfe101 | ||
|
|
b60dc31627 | ||
|
|
555f902fc1 | ||
|
|
1364548c72 | ||
|
|
2dadb34860 | ||
|
|
1cf86f5ae5 | ||
|
|
a1127b232d | ||
|
|
896f2e653c | ||
|
|
40ae495ddc | ||
|
|
653ceab414 | ||
|
|
160698eb41 | ||
|
|
7eca95657c | ||
|
|
ad5aef2d0c | ||
|
|
bcfd80dd79 | ||
|
|
6b4b671ce7 | ||
|
|
a9cf1cd249 | ||
|
|
255572188f | ||
|
|
0572029fee | ||
|
|
196954ab8c | ||
|
|
1e098d6132 | ||
|
|
cd66d72b46 | ||
|
|
2103e39335 | ||
|
|
d20576e6a3 | ||
|
|
a061b06321 | ||
|
|
80718908a9 | ||
|
|
7ea173c187 | ||
|
|
76eb1d72c3 | ||
|
|
c4a46e943c | ||
|
|
2b7f9a8196 | ||
|
|
ce4cb2389c | ||
|
|
c8d2117f02 | ||
|
|
fccab99ec0 | ||
|
|
fd79d32f38 | ||
|
|
341b4adefd | ||
|
|
b8730510db | ||
|
|
e808790799 | ||
|
|
145b0e4f79 | ||
|
|
707b2638ec | ||
|
|
8a5ac527e6 | ||
|
|
e3206351b0 | ||
|
|
1fee8827cb | ||
|
|
27bc181c49 | ||
|
|
d1d9eb94b1 | ||
|
|
7be2b49b6b | ||
|
|
9ed3c5cc09 | ||
|
|
66241cef31 | ||
|
|
e8df53b764 | ||
|
|
852704c81a | ||
|
|
9fdf8c25ab | ||
|
|
dc95b6acc0 | ||
|
|
711bcf33ee | ||
|
|
24b0fce099 | ||
|
|
1ea8c54064 | ||
|
|
8d6653fca6 | ||
|
|
dd611a7700 | ||
|
|
9288c78fc5 | ||
|
|
e42682b24e | ||
|
|
a39ac59c3e | ||
|
|
1a85483da1 | ||
|
|
47a9cde5d3 | ||
|
|
4f1f26ac6c | ||
|
|
f228367c5e | ||
|
|
80b7c9455b | ||
|
|
c1297f4eb3 | ||
|
|
e5e70636e7 | ||
|
|
29bf807b0e | ||
|
|
2559dee492 | ||
|
|
a3b04de700 | ||
|
|
d7f40442f9 | ||
|
|
b149e2e1e3 | ||
|
|
581bae2af3 | ||
|
|
af99928f22 | ||
|
|
53c9c7d39a | ||
|
|
ba68e83f1c | ||
|
|
dcb8834983 | ||
|
|
f9d2e4b742 | ||
|
|
45bc1f5c00 | ||
|
|
0aa074a420 | ||
|
|
7757d5a657 | ||
|
|
e600520f8a | ||
|
|
fd2b820ec2 | ||
|
|
d6b977b2e6 | ||
|
|
15ec9ea958 | ||
|
|
33bd9ed9cb | ||
|
|
18de0b2830 | ||
|
|
df6850fae8 | ||
|
|
e01e99d075 | ||
|
|
72212fef66 | ||
|
|
df34f1549a | ||
|
|
9b0553809c | ||
|
|
8d7c930246 | ||
|
|
de44b95db6 | ||
|
|
543888d3d8 | ||
|
|
70fc0425b3 | ||
|
|
85e34643f8 | ||
|
|
5c33872e2f | ||
|
|
206595f854 | ||
|
|
b288fb0db8 | ||
|
|
f73b176abd | ||
|
|
103a12cb66 | ||
|
|
97652d26b8 | ||
|
|
bd1d9bcd5f | ||
|
|
fb763d4333 | ||
|
|
bcbd7884e3 | ||
|
|
27a0fcccc3 | ||
|
|
ea6cdd2631 | ||
|
|
2ee7879a0b | ||
|
|
3493b9cb1f | ||
|
|
c9ebe70072 | ||
|
|
261421e218 | ||
|
|
a9f1bb10a5 | ||
|
|
b0338e930b | ||
|
|
b71f9bcb71 | ||
|
|
72855db715 | ||
|
|
f48d05a2d1 | ||
|
|
4368d8f87f | ||
|
|
22da0a83e9 | ||
|
|
50333f1715 | ||
|
|
26d5b86da8 | ||
|
|
4f5812b937 | ||
|
|
1bcb469089 | ||
|
|
464ba1d614 | ||
|
|
e3018c2a5a | ||
|
|
3412d53b1d | ||
|
|
e2d1e5dad9 | ||
|
|
27e067ce50 | ||
|
|
9b15155972 | ||
|
|
32a627bf1f | ||
|
|
fe442fac2e | ||
|
|
d2c502e629 | ||
|
|
fea9ea8268 | ||
|
|
f949094b3c | ||
|
|
4449e14769 | ||
|
|
885015eecf | ||
|
|
a86aaa4301 | ||
|
|
2efb2cbc38 | ||
|
|
15aa9222c4 | ||
|
|
c7bb3e2bce | ||
|
|
e80a14ad50 | ||
|
|
d28b39d93d | ||
|
|
1c184c29eb | ||
|
|
edde0b5043 | ||
|
|
0063610177 | ||
|
|
ce0052c087 | ||
|
|
0eb821a7b6 | ||
|
|
4aa79dbf2c | ||
|
|
38f697d953 | ||
|
|
3aad339b63 | ||
|
|
491755325c | ||
|
|
496888fd68 | ||
|
|
b5ac6ed7ce | ||
|
|
b20ba1f27c | ||
|
|
31a37686d0 | ||
|
|
88aee596a3 | ||
|
|
6a193ac557 | ||
|
|
47f4db3e84 | ||
|
|
5352abc6d3 | ||
|
|
39aa06bd5d | ||
|
|
914c2a2973 | ||
|
|
e633a47ad1 | ||
|
|
f6b93d41a0 | ||
|
|
95ac7794b7 | ||
|
|
71ed4a399e | ||
|
|
3e316c6338 | ||
|
|
8be0d22ab7 | ||
|
|
59eddda900 | ||
|
|
41048c69b4 | ||
|
|
fc247150fe | ||
|
|
fe31ad0276 | ||
|
|
ca4e96a8ae | ||
|
|
050c67323c | ||
|
|
497d41fb50 | ||
|
|
ff57793659 | ||
|
|
f7bd5e58dd | ||
|
|
7ed73d12d1 | ||
|
|
eb39019daa | ||
|
|
bab08f40d1 | ||
|
|
bc49106837 | ||
|
|
1b2de2642d | ||
|
|
9fa1036f60 | ||
|
|
0737b7e0d2 | ||
|
|
0963493a9c | ||
|
|
e73a9dbe30 | ||
|
|
fe01885acf | ||
|
|
7139d6d93f | ||
|
|
2f52e8f05f | ||
|
|
8d38ea3bbf | ||
|
|
5a8f502db5 | ||
|
|
7cd2c4bd6a | ||
|
|
dfa791eb4b | ||
|
|
bddd69618b | ||
|
|
54d8fdbed0 | ||
|
|
d844d8b13b | ||
|
|
07a927517c | ||
|
|
f16a70ba67 | ||
|
|
36b5127fd3 | ||
|
|
4977f203fa | ||
|
|
bd2ab73976 | ||
|
|
da2efeaec6 | ||
|
|
7f3b9b16c6 | ||
|
|
d4e353a94e | ||
|
|
ed43784b0d | ||
|
|
0f2b8525bc | ||
|
|
20a84166d0 | ||
|
|
ed2e33c69a | ||
|
|
1702e6df16 | ||
|
|
c308a8840a | ||
|
|
027c63f63a | ||
|
|
e08ecfbd8a | ||
|
|
4e5c230f6a | ||
|
|
f0d5d0111f | ||
|
|
ad19a069f6 | ||
|
|
5d65d6753b | ||
|
|
deebee4ff6 | ||
|
|
fa570cbf59 | ||
|
|
644b23ac0b | ||
|
|
72fd4d22b6 | ||
|
|
e4f7ea105f | ||
|
|
c991a5da65 | ||
|
|
9df8792d4b | ||
|
|
3da5a07510 | ||
|
|
afa0a45206 | ||
|
|
615eb52049 | ||
|
|
d5c1954d5c | ||
|
|
e400f26c8f | ||
|
|
5ca8e2fac3 | ||
|
|
3294782d19 | ||
|
|
898d88e10e | ||
|
|
560d38f34c | ||
|
|
e1d4f36d8d | ||
|
|
1e3ae1eed8 | ||
|
|
f4231a80b1 | ||
|
|
2208aa616d | ||
|
|
629b173837 | ||
|
|
fa340add55 | ||
|
|
966f3a5206 | ||
|
|
0552de7c7d | ||
|
|
5828607ccf | ||
|
|
735bb4bdb1 | ||
|
|
bf2a1b5b1e | ||
|
|
42974a448c | ||
|
|
05df2df489 | ||
|
|
37d620a6b8 | ||
|
|
32691b16f4 | ||
|
|
4c3e57b0ae | ||
|
|
9126c0cfe4 | ||
|
|
d8c51ba15a | ||
|
|
32a95bba8a | ||
|
|
da1ad9b516 | ||
|
|
d044a24398 | ||
|
|
5be6fd09ff | ||
|
|
f69609bbd6 | ||
|
|
c012400240 | ||
|
|
03895dea7c | ||
|
|
84f9759424 | ||
|
|
7991341e89 | ||
|
|
140ffc7fdc | ||
|
|
182f90b5ec | ||
|
|
aebac22193 | ||
|
|
13aaa66ec2 | ||
|
|
5f582a9757 | ||
|
|
fbcc23945d | ||
|
|
3dfefc88d0 | ||
|
|
bff60b5cfc | ||
|
|
1e638a140b | ||
|
|
4696d74305 | ||
|
|
5ee381c058 | ||
|
|
4887743a2a | ||
|
|
97b8a2c26a | ||
|
|
97eb256a35 | ||
|
|
61b08d4ba6 | ||
|
|
da9dab7edd | ||
|
|
d2aaef029c | ||
|
|
0a3d062e06 | ||
|
|
2f74e17975 | ||
|
|
dca6bdd4fa | ||
|
|
7d593baf91 | ||
|
|
c60dc4177c | ||
|
|
5d4cc3ba1b | ||
|
|
9f1388c0a3 | ||
|
|
a88788dce6 | ||
|
|
d0210fe2e5 | ||
|
|
e6d9f62744 | ||
|
|
78672d0ee6 | ||
|
|
1ef70fcde4 | ||
|
|
0621d73a9c | ||
|
|
b850d9a8bb | ||
|
|
c60467a148 | ||
|
|
c0207b473f | ||
|
|
93bc2f8e4d | ||
|
|
e6e5d33b35 | ||
|
|
4293e4da21 | ||
|
|
69cb57b342 | ||
|
|
d03ae077b4 | ||
|
|
0ccc88b03f | ||
|
|
eb2f78b4e0 | ||
|
|
e729a5cc11 | ||
|
|
e78d230496 | ||
|
|
d3504e1778 | ||
|
|
a86a58c308 | ||
|
|
39dda1d40d | ||
|
|
5ad33787de | ||
|
|
255f139863 | ||
|
|
5ac9ec214b | ||
|
|
0aa1c58b04 | ||
|
|
5249e45a1c | ||
|
|
54a45b9967 | ||
|
|
9a470e073e | ||
|
|
7d627f764c | ||
|
|
a0c0785635 | ||
|
|
100c2478ea | ||
|
|
1da5639e86 | ||
|
|
1b96fae1d4 | ||
|
|
7f492522b6 | ||
|
|
650838fd6f | ||
|
|
491fafbd64 | ||
|
|
9bc2798f72 | ||
|
|
50afba747c | ||
|
|
6b8062f414 | ||
|
|
b1ae4126c3 | ||
|
|
9dabda19f0 | ||
|
|
543c24108c | ||
|
|
260a5ca5d9 | ||
|
|
861c3bbb3d | ||
|
|
9ca581c941 | ||
|
|
4831e9c2c4 | ||
|
|
480375f349 | ||
|
|
b40143984c | ||
|
|
b43916a134 | ||
|
|
7bc7dd2aa2 | ||
|
|
938d3e8216 | ||
|
|
8f05fb48ea | ||
|
|
b7ff5bd14d | ||
|
|
2b653e8c18 | ||
|
|
1fd306824d | ||
|
|
1205afc708 | ||
|
|
5612670ee4 | ||
|
|
181a9bf26d | ||
|
|
aac10ad23a | ||
|
|
974254218a | ||
|
|
c5de4955bb | ||
|
|
9fd0cd7cf7 | ||
|
|
b5e97db9ac | ||
|
|
1359c969e4 | ||
|
|
059cd38aa2 | ||
|
|
e740dfd806 | ||
|
|
7eab7d2944 | ||
|
|
75d327abd5 | ||
|
|
ee615ac269 | ||
|
|
27870ec3c3 | ||
|
|
f41f323c52 | ||
|
|
f74fc4d927 | ||
|
|
ae26cd99b5 | ||
|
|
e9af97ba1a | ||
|
|
d9277301d2 | ||
|
|
34c8eeec06 | ||
|
|
9f1069290c | ||
|
|
111f583e00 | ||
|
|
79ed752748 | ||
|
|
772de7c006 | ||
|
|
b22e97dcfa | ||
|
|
f02de13316 | ||
|
|
c46268bf60 | ||
|
|
cf49a2c5b5 | ||
|
|
170c7bb90c | ||
|
|
2a0b138feb | ||
|
|
e195c1b13f | ||
|
|
5b4eb021cb | ||
|
|
396454fa41 | ||
|
|
a3cf272522 | ||
|
|
ba9548f756 | ||
|
|
e18f53cca9 | ||
|
|
c36be0ea09 | ||
|
|
9093301a49 | ||
|
|
bd951a714f | ||
|
|
6493709d6a | ||
|
|
b976f934ae | ||
|
|
7d8cf4cacc | ||
|
|
68f4496b8e | ||
|
|
ef5266b1c1 | ||
|
|
a96e65df18 | ||
|
|
93a49a45de | ||
|
|
ec70ed6aea | ||
|
|
7a13f74220 | ||
|
|
8042eb20c6 | ||
|
|
bd9f166c12 | ||
|
|
dd94416db2 | ||
|
|
ae0e7c4dff | ||
|
|
78f79266a9 | ||
|
|
1883e70b43 | ||
|
|
31ca603ccb | ||
|
|
f7fb193712 | ||
|
|
7e9267fa77 | ||
|
|
91d40086db | ||
|
|
5b12b55e32 | ||
|
|
e9e9a031a8 | ||
|
|
d7430c529a | ||
|
|
cd88f709ab | ||
|
|
4459a17e82 | ||
|
|
483b3e62e0 | ||
|
|
8e81c507d2 | ||
|
|
e1c6dc720e | ||
|
|
7ea79ebb9d | ||
|
|
ae75a084df | ||
|
|
d6a2137fc3 | ||
|
|
53e8d8193c | ||
|
|
29596bd53f | ||
|
|
803af1e0c3 | ||
|
|
6673939e76 | ||
|
|
f74778e75d | ||
|
|
520eb77b72 | ||
|
|
5bf69bde35 | ||
|
|
c69af655aa | ||
|
|
251f54a2ad | ||
|
|
c6529c0d77 | ||
|
|
baa8c8cdd3 | ||
|
|
40fd39c7cb | ||
|
|
4d1c4b9797 | ||
|
|
d2566eb4b2 | ||
|
|
ef7e885fe4 | ||
|
|
ecb8d15e7a | ||
|
|
365f9ed157 | ||
|
|
50c605e957 | ||
|
|
9685d4f3c3 | ||
|
|
8a4ff747bd | ||
|
|
af1eb58be8 | ||
|
|
373a9386a4 | ||
|
|
6e28a46454 | ||
|
|
c7b25784b1 | ||
|
|
7f800d04fa | ||
|
|
97755eed46 | ||
|
|
daf9d25ee2 | ||
|
|
3b4b171e18 | ||
|
|
d8759c772b | ||
|
|
4248b1618f |
@ -53,6 +53,16 @@ try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash") # noqa: T201
|
||||
except:
|
||||
print("Could not stash, cleaning index and trying again.") # noqa: T201
|
||||
repo.state_cleanup()
|
||||
repo.index.read_tree(repo.head.peel().tree)
|
||||
repo.index.write()
|
||||
try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash.") # noqa: T201
|
||||
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||
try:
|
||||
@ -66,8 +76,10 @@ if branch is None:
|
||||
try:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
except:
|
||||
print("pulling.") # noqa: T201
|
||||
pull(repo)
|
||||
print("fetching.") # noqa: T201
|
||||
for remote in repo.remotes:
|
||||
if remote.name == "origin":
|
||||
remote.fetch()
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
repo.checkout(ref)
|
||||
branch = repo.lookup_branch('master')
|
||||
@ -149,3 +161,4 @@ try:
|
||||
shutil.copy(stable_update_script, stable_update_script_to)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
28
.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt
Executable file
28
.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt
Executable file
@ -0,0 +1,28 @@
|
||||
As of the time of writing this you need this driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
||||
|
||||
HOW TO RUN:
|
||||
|
||||
If you have a AMD gpu:
|
||||
|
||||
run_amd_gpu.bat
|
||||
|
||||
If you have memory issues you can try disabling the smart memory management by running comfyui with:
|
||||
|
||||
run_amd_gpu_disable_smart_memory.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
You can download the stable diffusion XL one from: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors
|
||||
|
||||
|
||||
RECOMMENDED WAY TO UPDATE:
|
||||
To update the ComfyUI code: update\update_comfyui.bat
|
||||
|
||||
|
||||
TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
|
||||
In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
|
||||
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
|
||||
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
|
||||
pause
|
||||
@ -4,6 +4,9 @@ if you have a NVIDIA gpu:
|
||||
|
||||
run_nvidia_gpu.bat
|
||||
|
||||
if you want to enable the fast fp16 accumulation (faster for fp16 models with slightly less quality):
|
||||
|
||||
run_nvidia_gpu_fast_fp16_accumulation.bat
|
||||
|
||||
|
||||
To run it in slow CPU mode:
|
||||
@ -0,0 +1,3 @@
|
||||
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
pause
|
||||
3
.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
Executable file
3
.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
Executable file
@ -0,0 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
pause
|
||||
@ -0,0 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
pause
|
||||
1
.gitattributes
vendored
1
.gitattributes
vendored
@ -1,2 +1,3 @@
|
||||
/web/assets/** linguist-generated
|
||||
/web/** linguist-vendored
|
||||
comfy_api_nodes/apis/__init__.py linguist-generated
|
||||
|
||||
10
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
10
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -8,13 +8,15 @@ body:
|
||||
Before submitting a **Bug Report**, please ensure the following:
|
||||
|
||||
- **1:** You are running the latest version of ComfyUI.
|
||||
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
|
||||
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
|
||||
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
||||
`--disable-all-custom-nodes` command line argument.
|
||||
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
|
||||
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
## Very Important
|
||||
|
||||
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
@ -22,7 +24,7 @@ body:
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/user-support.yml
vendored
2
.github/ISSUE_TEMPLATE/user-support.yml
vendored
@ -18,7 +18,7 @@ body:
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your question
|
||||
|
||||
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
Normal file
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
<!-- API_NODE_PR_CHECKLIST: do not remove -->
|
||||
|
||||
## API Node PR Checklist
|
||||
|
||||
### Scope
|
||||
- [ ] **Is API Node Change**
|
||||
|
||||
### Pricing & Billing
|
||||
- [ ] **Need pricing update**
|
||||
- [ ] **No pricing update**
|
||||
|
||||
If **Need pricing update**:
|
||||
- [ ] Metronome rate cards updated
|
||||
- [ ] Auto‑billing tests updated and passing
|
||||
|
||||
### QA
|
||||
- [ ] **QA done**
|
||||
- [ ] **QA not required**
|
||||
|
||||
### Comms
|
||||
- [ ] Informed **Kosinkadink**
|
||||
58
.github/workflows/api-node-template.yml
vendored
Normal file
58
.github/workflows/api-node-template.yml
vendored
Normal file
@ -0,0 +1,58 @@
|
||||
name: Append API Node PR template
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, reopened, synchronize, ready_for_review]
|
||||
paths:
|
||||
- 'comfy_api_nodes/**' # only run if these files changed
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
inject:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Ensure template exists and append to PR body
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const { owner, repo } = context.repo;
|
||||
const number = context.payload.pull_request.number;
|
||||
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
|
||||
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
|
||||
|
||||
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
|
||||
|
||||
let templateText;
|
||||
try {
|
||||
const res = await github.rest.repos.getContent({
|
||||
owner,
|
||||
repo,
|
||||
path: templatePath,
|
||||
ref: pr.base.ref
|
||||
});
|
||||
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
|
||||
templateText = buf.toString('utf8');
|
||||
} catch (e) {
|
||||
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Enforce the presence of the marker inside the template (for idempotence)
|
||||
if (!templateText.includes(marker)) {
|
||||
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// If the PR already contains the marker, do not append again.
|
||||
const body = pr.body || '';
|
||||
if (body.includes(marker)) {
|
||||
core.info('Template already present in PR body; nothing to inject.');
|
||||
return;
|
||||
}
|
||||
|
||||
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
|
||||
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
|
||||
core.notice('API Node template appended to PR description.');
|
||||
40
.github/workflows/check-line-endings.yml
vendored
Normal file
40
.github/workflows/check-line-endings.yml
vendored
Normal file
@ -0,0 +1,40 @@
|
||||
name: Check for Windows Line Endings
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['*'] # Trigger on all pull requests to any branch
|
||||
|
||||
jobs:
|
||||
check-line-endings:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history to compare changes
|
||||
|
||||
- name: Check for Windows line endings (CRLF)
|
||||
run: |
|
||||
# Get the list of changed files in the PR
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
||||
|
||||
# Flag to track if CRLF is found
|
||||
CRLF_FOUND=false
|
||||
|
||||
# Loop through each changed file
|
||||
for FILE in $CHANGED_FILES; do
|
||||
# Check if the file exists and is a text file
|
||||
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
|
||||
# Check for CRLF line endings
|
||||
if grep -UP '\r$' "$FILE"; then
|
||||
echo "Error: Windows line endings (CRLF) detected in $FILE"
|
||||
CRLF_FOUND=true
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Exit with error if CRLF was found
|
||||
if [ "$CRLF_FOUND" = true ]; then
|
||||
exit 1
|
||||
fi
|
||||
78
.github/workflows/release-stable-all.yml
vendored
Normal file
78
.github/workflows/release-stable-all.yml
vendored
Normal file
@ -0,0 +1,78 @@
|
||||
name: "Release Stable All Portable Versions"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
release_nvidia_default:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA Default (cu130)"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "9"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu128:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu128"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu128"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu128"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu126:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu126"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu126"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu126"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_amd_rocm:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm711"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
rel_extra_name: ""
|
||||
test_release: false
|
||||
secrets: inherit
|
||||
108
.github/workflows/release-webhook.yml
vendored
Normal file
108
.github/workflows/release-webhook.yml
vendored
Normal file
@ -0,0 +1,108 @@
|
||||
name: Release Webhook
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
|
||||
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
|
||||
run: |
|
||||
# Generate UUID for delivery ID
|
||||
DELIVERY_ID=$(uuidgen)
|
||||
HOOK_ID="release-webhook-$(date +%s)"
|
||||
|
||||
# Create webhook payload matching GitHub release webhook format
|
||||
PAYLOAD=$(cat <<EOF
|
||||
{
|
||||
"action": "published",
|
||||
"release": {
|
||||
"id": ${{ github.event.release.id }},
|
||||
"node_id": "${{ github.event.release.node_id }}",
|
||||
"url": "${{ github.event.release.url }}",
|
||||
"html_url": "${{ github.event.release.html_url }}",
|
||||
"assets_url": "${{ github.event.release.assets_url }}",
|
||||
"upload_url": "${{ github.event.release.upload_url }}",
|
||||
"tag_name": "${{ github.event.release.tag_name }}",
|
||||
"target_commitish": "${{ github.event.release.target_commitish }}",
|
||||
"name": ${{ toJSON(github.event.release.name) }},
|
||||
"body": ${{ toJSON(github.event.release.body) }},
|
||||
"draft": ${{ github.event.release.draft }},
|
||||
"prerelease": ${{ github.event.release.prerelease }},
|
||||
"created_at": "${{ github.event.release.created_at }}",
|
||||
"published_at": "${{ github.event.release.published_at }}",
|
||||
"author": {
|
||||
"login": "${{ github.event.release.author.login }}",
|
||||
"id": ${{ github.event.release.author.id }},
|
||||
"node_id": "${{ github.event.release.author.node_id }}",
|
||||
"avatar_url": "${{ github.event.release.author.avatar_url }}",
|
||||
"url": "${{ github.event.release.author.url }}",
|
||||
"html_url": "${{ github.event.release.author.html_url }}",
|
||||
"type": "${{ github.event.release.author.type }}",
|
||||
"site_admin": ${{ github.event.release.author.site_admin }}
|
||||
},
|
||||
"tarball_url": "${{ github.event.release.tarball_url }}",
|
||||
"zipball_url": "${{ github.event.release.zipball_url }}",
|
||||
"assets": ${{ toJSON(github.event.release.assets) }}
|
||||
},
|
||||
"repository": {
|
||||
"id": ${{ github.event.repository.id }},
|
||||
"node_id": "${{ github.event.repository.node_id }}",
|
||||
"name": "${{ github.event.repository.name }}",
|
||||
"full_name": "${{ github.event.repository.full_name }}",
|
||||
"private": ${{ github.event.repository.private }},
|
||||
"owner": {
|
||||
"login": "${{ github.event.repository.owner.login }}",
|
||||
"id": ${{ github.event.repository.owner.id }},
|
||||
"node_id": "${{ github.event.repository.owner.node_id }}",
|
||||
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
|
||||
"url": "${{ github.event.repository.owner.url }}",
|
||||
"html_url": "${{ github.event.repository.owner.html_url }}",
|
||||
"type": "${{ github.event.repository.owner.type }}",
|
||||
"site_admin": ${{ github.event.repository.owner.site_admin }}
|
||||
},
|
||||
"html_url": "${{ github.event.repository.html_url }}",
|
||||
"clone_url": "${{ github.event.repository.clone_url }}",
|
||||
"git_url": "${{ github.event.repository.git_url }}",
|
||||
"ssh_url": "${{ github.event.repository.ssh_url }}",
|
||||
"url": "${{ github.event.repository.url }}",
|
||||
"created_at": "${{ github.event.repository.created_at }}",
|
||||
"updated_at": "${{ github.event.repository.updated_at }}",
|
||||
"pushed_at": "${{ github.event.repository.pushed_at }}",
|
||||
"default_branch": "${{ github.event.repository.default_branch }}",
|
||||
"fork": ${{ github.event.repository.fork }}
|
||||
},
|
||||
"sender": {
|
||||
"login": "${{ github.event.sender.login }}",
|
||||
"id": ${{ github.event.sender.id }},
|
||||
"node_id": "${{ github.event.sender.node_id }}",
|
||||
"avatar_url": "${{ github.event.sender.avatar_url }}",
|
||||
"url": "${{ github.event.sender.url }}",
|
||||
"html_url": "${{ github.event.sender.html_url }}",
|
||||
"type": "${{ github.event.sender.type }}",
|
||||
"site_admin": ${{ github.event.sender.site_admin }}
|
||||
}
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
# Generate HMAC-SHA256 signature
|
||||
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
|
||||
|
||||
# Send webhook with required headers
|
||||
curl -X POST "$WEBHOOK_URL" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-GitHub-Event: release" \
|
||||
-H "X-GitHub-Delivery: $DELIVERY_ID" \
|
||||
-H "X-GitHub-Hook-ID: $HOOK_ID" \
|
||||
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
|
||||
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
|
||||
-d "$PAYLOAD" \
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
||||
25
.github/workflows/ruff.yml
vendored
25
.github/workflows/ruff.yml
vendored
@ -21,3 +21,28 @@ jobs:
|
||||
|
||||
- name: Run Ruff
|
||||
run: ruff check .
|
||||
|
||||
pylint:
|
||||
name: Run Pylint
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Install Pylint
|
||||
run: pip install pylint
|
||||
|
||||
- name: Run Pylint
|
||||
run: pylint comfy_api_nodes
|
||||
|
||||
110
.github/workflows/stable-release.yml
vendored
110
.github/workflows/stable-release.yml
vendored
@ -2,28 +2,78 @@
|
||||
name: "Release Stable Version"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "13"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
rel_name:
|
||||
description: 'Release name'
|
||||
required: true
|
||||
type: string
|
||||
default: "nvidia"
|
||||
rel_extra_name:
|
||||
description: 'Release extra name'
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
test_release:
|
||||
description: 'Test Release'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_tag:
|
||||
description: 'Git tag'
|
||||
required: true
|
||||
type: string
|
||||
cu:
|
||||
description: 'CUDA version'
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "cu129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
|
||||
default: "6"
|
||||
rel_name:
|
||||
description: 'Release name'
|
||||
required: true
|
||||
type: string
|
||||
default: "nvidia"
|
||||
rel_extra_name:
|
||||
description: 'Release extra name'
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
test_release:
|
||||
description: 'Test Release'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
jobs:
|
||||
package_comfy_windows:
|
||||
@ -42,15 +92,15 @@ jobs:
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
cu${{ inputs.cu }}_python_deps.tar
|
||||
${{ inputs.cache_tag }}_python_deps.tar
|
||||
update_comfyui_and_python_dependencies.bat
|
||||
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
|
||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
||||
- shell: bash
|
||||
run: |
|
||||
mv cu${{ inputs.cu }}_python_deps.tar ../
|
||||
mv ${{ inputs.cache_tag }}_python_deps.tar ../
|
||||
mv update_comfyui_and_python_dependencies.bat ../
|
||||
cd ..
|
||||
tar xf cu${{ inputs.cu }}_python_deps.tar
|
||||
tar xf ${{ inputs.cache_tag }}_python_deps.tar
|
||||
pwd
|
||||
ls
|
||||
|
||||
@ -65,9 +115,21 @@ jobs:
|
||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
cd ..
|
||||
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
||||
|
||||
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
./python.exe -s -m pip install -r requirements_comfyui.txt
|
||||
rm requirements_comfyui.txt
|
||||
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
if test -f ./Lib/site-packages/torch/lib/dnnl.lib; then
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
fi
|
||||
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
@ -80,14 +142,18 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
||||
|
||||
- shell: bash
|
||||
if: ${{ inputs.test_release }}
|
||||
run: |
|
||||
cd ..
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
@ -96,11 +162,9 @@ jobs:
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
file: ComfyUI_windows_portable_nvidia.7z
|
||||
tag: ${{ inputs.git_tag }}
|
||||
overwrite: true
|
||||
prerelease: true
|
||||
make_latest: false
|
||||
files: ComfyUI_windows_portable_${{ inputs.rel_name }}${{ inputs.rel_extra_name }}.7z
|
||||
tag_name: ${{ inputs.git_tag }}
|
||||
draft: true
|
||||
overwrite_files: true
|
||||
|
||||
21
.github/workflows/test-ci.yml
vendored
21
.github/workflows/test-ci.yml
vendored
@ -5,6 +5,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
paths-ignore:
|
||||
- 'app/**'
|
||||
- 'input/**'
|
||||
@ -21,14 +22,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# os: [macos, linux, windows]
|
||||
os: [macos, linux]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
python_version: ["3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
@ -73,14 +75,15 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux]
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
python_version: ["3.11"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
|
||||
30
.github/workflows/test-execution.yml
vendored
Normal file
30
.github/workflows/test-execution.yml
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
name: Execution Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install -r tests-unit/requirements.txt
|
||||
- name: Run Execution Tests
|
||||
run: |
|
||||
python -m pytest tests/execution -v --skip-timing-checks
|
||||
4
.github/workflows/test-launch.yml
vendored
4
.github/workflows/test-launch.yml
vendored
@ -2,9 +2,9 @@ name: Test server launches without errors
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
||||
6
.github/workflows/test-unit.yml
vendored
6
.github/workflows/test-unit.yml
vendored
@ -2,15 +2,15 @@ name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
os: [ubuntu-latest, windows-2022, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
|
||||
1
.github/workflows/update-version.yml
vendored
1
.github/workflows/update-version.yml
vendored
@ -6,6 +6,7 @@ on:
|
||||
- "pyproject.toml"
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
|
||||
@ -17,19 +17,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "130"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "9"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -56,7 +56,8 @@ jobs:
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
|
||||
64
.github/workflows/windows_release_dependencies_manual.yml
vendored
Normal file
64
.github/workflows/windows_release_dependencies_manual.yml
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
name: "Windows Release dependencies Manual"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
torch_dependencies:
|
||||
description: 'torch dependencies'
|
||||
required: false
|
||||
type: string
|
||||
default: "torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128"
|
||||
cache_tag:
|
||||
description: 'Cached dependencies tag'
|
||||
required: true
|
||||
type: string
|
||||
default: "cu128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
|
||||
jobs:
|
||||
build_dependencies:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
||||
|
||||
- shell: bash
|
||||
run: |
|
||||
echo "@echo off
|
||||
call update_comfyui.bat nopause
|
||||
echo -
|
||||
echo This will try to update pytorch and all python dependencies.
|
||||
echo -
|
||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||
echo -
|
||||
pause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade ${{ inputs.torch_dependencies }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
grep -v comfyui requirements.txt > requirements_nocomfyui.txt
|
||||
python -m pip wheel --no-cache-dir ${{ inputs.torch_dependencies }} -r requirements_nocomfyui.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
mv temp_wheel_dir ${{ inputs.cache_tag }}_python_deps
|
||||
tar cf ${{ inputs.cache_tag }}_python_deps.tar ${{ inputs.cache_tag }}_python_deps
|
||||
|
||||
- uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
${{ inputs.cache_tag }}_python_deps.tar
|
||||
update_comfyui_and_python_dependencies.bat
|
||||
key: ${{ runner.os }}-build-${{ inputs.cache_tag }}-${{ inputs.python_minor }}
|
||||
@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "2"
|
||||
default: "5"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -53,6 +53,8 @@ jobs:
|
||||
ls ../temp_wheel_dir
|
||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
@ -66,7 +68,7 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
|
||||
|
||||
echo "call update_comfyui.bat nopause
|
||||
|
||||
14
.github/workflows/windows_release_package.yml
vendored
14
.github/workflows/windows_release_package.yml
vendored
@ -7,19 +7,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -64,6 +64,10 @@ jobs:
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
@ -77,12 +81,12 @@ jobs:
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/windows_nvidia_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
|
||||
24
CODEOWNERS
24
CODEOWNERS
@ -1,24 +1,2 @@
|
||||
# Admins
|
||||
* @comfyanonymous
|
||||
|
||||
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
* @comfyanonymous @kosinkadink @guill
|
||||
|
||||
168
QUANTIZATION.md
Normal file
168
QUANTIZATION.md
Normal file
@ -0,0 +1,168 @@
|
||||
# The Comfy guide to Quantization
|
||||
|
||||
|
||||
## How does quantization work?
|
||||
|
||||
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
|
||||
|
||||
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
|
||||
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
|
||||
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
|
||||
|
||||
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
|
||||
|
||||
```
|
||||
absmax = max(abs(tensor))
|
||||
scale = amax / max_dynamic_range_low_precision
|
||||
|
||||
# Quantization
|
||||
tensor_q = (tensor / scale).to(low_precision_dtype)
|
||||
|
||||
# De-Quantization
|
||||
tensor_dq = tensor_q.to(fp16) * scale
|
||||
|
||||
tensor_dq ~ tensor
|
||||
```
|
||||
|
||||
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
|
||||
|
||||
|
||||
## Quantization in Comfy
|
||||
|
||||
```
|
||||
QuantizedTensor (torch.Tensor subclass)
|
||||
↓ __torch_dispatch__
|
||||
Two-Level Registry (generic + layout handlers)
|
||||
↓
|
||||
MixedPrecisionOps + Metadata Detection
|
||||
```
|
||||
|
||||
### Representation
|
||||
|
||||
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
|
||||
|
||||
A `Layout` class defines how a specific quantization format behaves:
|
||||
- Required parameters
|
||||
- Quantize method
|
||||
- De-Quantize method
|
||||
|
||||
```python
|
||||
from comfy.quant_ops import QuantizedLayout
|
||||
|
||||
class MyLayout(QuantizedLayout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs):
|
||||
# Convert to quantized format
|
||||
qdata = ...
|
||||
params = {'scale': ..., 'orig_dtype': tensor.dtype}
|
||||
return qdata, params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
return qdata.to(orig_dtype) * scale
|
||||
```
|
||||
|
||||
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
|
||||
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
|
||||
|
||||
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
|
||||
```python
|
||||
from comfy.quant_ops import register_layout_op
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
|
||||
def my_linear(func, args, kwargs):
|
||||
# Extract tensors, call optimized kernel
|
||||
...
|
||||
```
|
||||
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
|
||||
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
|
||||
|
||||
|
||||
### Mixed Precision
|
||||
|
||||
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
|
||||
|
||||
**Architecture:**
|
||||
|
||||
```python
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {} # Maps layer names to quantization configs
|
||||
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
|
||||
```
|
||||
|
||||
**Key mechanism:**
|
||||
|
||||
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
|
||||
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
|
||||
- If the layer name **is** in `_layer_quant_config`:
|
||||
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
|
||||
- Load associated quantization parameters (scales, block_size, etc.)
|
||||
|
||||
**Why it's needed:**
|
||||
|
||||
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
|
||||
|
||||
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
|
||||
|
||||
|
||||
## Checkpoint Format
|
||||
|
||||
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
|
||||
|
||||
The quantized checkpoint will contain the same layers as the original checkpoint but:
|
||||
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
|
||||
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
|
||||
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
|
||||
|
||||
### Scaling Parameters details
|
||||
We define 4 possible scaling parameters that should cover most recipes in the near-future:
|
||||
- **weight_scale**: quantization scalers for the weights
|
||||
- **weight_scale_2**: global scalers in the context of double scaling
|
||||
- **pre_quant_scale**: scalers used for smoothing salient weights
|
||||
- **input_scale**: quantization scalers for the activations
|
||||
|
||||
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|
||||
|--------|---------------|--------------|----------------|-----------------|-------------|
|
||||
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
|
||||
|
||||
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
|
||||
|
||||
### Quantization Metadata
|
||||
|
||||
The metadata stored alongside the checkpoint contains:
|
||||
- **format_version**: String to define a version of the standard
|
||||
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"_quantization_metadata": {
|
||||
"format_version": "1.0",
|
||||
"layers": {
|
||||
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
||||
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
||||
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Creating Quantized Checkpoints
|
||||
|
||||
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
|
||||
|
||||
### Weight Quantization
|
||||
|
||||
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
|
||||
|
||||
### Calibration (for Activation Quantization)
|
||||
|
||||
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
|
||||
|
||||
1. **Collect statistics**: Run inference on N representative samples
|
||||
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
|
||||
3. **Compute scales**: Derive `input_scale` from collected statistics
|
||||
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
||||
|
||||
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
||||
148
README.md
148
README.md
@ -6,6 +6,7 @@
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
[![Dynamic JSON Badge][discord-shield]][discord-url]
|
||||
[![Twitter][twitter-shield]][twitter-url]
|
||||
[![Matrix][matrix-shield]][matrix-url]
|
||||
<br>
|
||||
[![][github-release-shield]][github-release-link]
|
||||
@ -20,6 +21,8 @@
|
||||
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||
[discord-url]: https://www.comfy.org/discord
|
||||
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
|
||||
[twitter-url]: https://x.com/ComfyUI
|
||||
|
||||
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
|
||||
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
@ -36,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
## Get Started
|
||||
|
||||
#### [Desktop Application](https://www.comfy.org/download)
|
||||
- The easiest way to get started.
|
||||
- The easiest way to get started.
|
||||
- Available on Windows & macOS.
|
||||
|
||||
#### [Windows Portable Package](#installing)
|
||||
@ -52,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Image Models
|
||||
- SD1.x, SD2.x,
|
||||
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
|
||||
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
||||
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
|
||||
@ -62,13 +65,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
||||
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
@ -76,9 +89,10 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
|
||||
- Works even if you don't have a GPU with: ```--cpu``` (slow)
|
||||
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
|
||||
- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
|
||||
- Safe loading of ckpt, pt, pth, etc.. files.
|
||||
- Embeddings/Textual inversion
|
||||
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
|
||||
@ -89,12 +103,10 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
|
||||
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
|
||||
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
|
||||
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
|
||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: core will never download anything unless you want to.
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
@ -103,10 +115,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
|
||||
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||
@ -163,18 +176,24 @@ There is a portable standalone build for Windows that should work for running on
|
||||
|
||||
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
|
||||
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
|
||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
|
||||
## Jupyter Notebook
|
||||
|
||||
To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
|
||||
|
||||
|
||||
## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
|
||||
|
||||
@ -186,7 +205,11 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
|
||||
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
### Instructions:
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
@ -195,48 +218,54 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||
Put your VAE in: models/vae
|
||||
|
||||
|
||||
### AMD GPUs (Linux only)
|
||||
### AMD GPUs (Linux)
|
||||
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
|
||||
|
||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||
|
||||
These have less hardware support than the builds above but they work on windows. You also need to install the pytorch version specific to your hardware.
|
||||
|
||||
RDNA 3 (RX 7000 series):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
|
||||
|
||||
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx1151/```
|
||||
|
||||
RDNA 4 (RX 9000 series):
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/```
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch nightly, use the following command:
|
||||
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch xpu, use the following command:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
|
||||
|
||||
This is the command to install the Pytorch xpu nightly which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
|
||||
|
||||
```
|
||||
conda install libuv
|
||||
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
|
||||
```
|
||||
|
||||
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
@ -267,10 +296,6 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
|
||||
|
||||
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
|
||||
|
||||
#### DirectML (AMD Cards on Windows)
|
||||
|
||||
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
|
||||
|
||||
#### Ascend NPUs
|
||||
|
||||
For models compatible with Ascend Extension for PyTorch (torch_npu). To get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
|
||||
@ -288,6 +313,39 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
|
||||
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
|
||||
3. Launch ComfyUI by running `python main.py`
|
||||
|
||||
#### Iluvatar Corex
|
||||
|
||||
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
|
||||
|
||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
|
||||
|
||||
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
|
||||
|
||||
### Setup
|
||||
|
||||
1. Install the manager dependencies:
|
||||
```bash
|
||||
pip install -r manager_requirements.txt
|
||||
```
|
||||
|
||||
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
|
||||
```bash
|
||||
python main.py --enable-manager
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
@ -338,7 +396,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a
|
||||
|
||||
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
|
||||
|
||||
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
|
||||
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
|
||||
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
|
||||
|
||||
## Support and dev channel
|
||||
|
||||
84
alembic.ini
Normal file
84
alembic.ini
Normal file
@ -0,0 +1,84 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = alembic_db
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic_db/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
# version_path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
version_path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = sqlite:///user/comfyui.db
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
4
alembic_db/README.md
Normal file
4
alembic_db/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
## Generate new revision
|
||||
|
||||
1. Update models in `/app/database/models.py`
|
||||
2. Run `alembic revision --autogenerate -m "{your message}"`
|
||||
64
alembic_db/env.py
Normal file
64
alembic_db/env.py
Normal file
@ -0,0 +1,64 @@
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
28
alembic_db/script.py.mako
Normal file
28
alembic_db/script.py.mako
Normal file
@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
@ -58,8 +58,13 @@ class InternalRoutes:
|
||||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
||||
|
||||
directory = get_directory_by_type(directory_type)
|
||||
|
||||
def is_visible_file(entry: os.DirEntry) -> bool:
|
||||
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
|
||||
return entry.is_file() and not entry.name.startswith('.')
|
||||
|
||||
sorted_files = sorted(
|
||||
(entry for entry in os.scandir(directory) if entry.is_file()),
|
||||
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
||||
key=lambda entry: -entry.stat().st_mtime
|
||||
)
|
||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
||||
|
||||
112
app/database/db.py
Normal file
112
app/database/db.py
Normal file
@ -0,0 +1,112 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
|
||||
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
log_startup_warning(
|
||||
f"""
|
||||
------------------------------------------------------------------------
|
||||
Error importing dependencies: {e}
|
||||
{get_missing_requirements_message()}
|
||||
This error is happening because ComfyUI now uses a local sqlite database.
|
||||
------------------------------------------------------------------------
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def dependencies_available():
|
||||
"""
|
||||
Temporary function to check if the dependencies are available
|
||||
"""
|
||||
return _DB_AVAILABLE
|
||||
|
||||
|
||||
def can_create_session():
|
||||
"""
|
||||
Temporary function to check if the database is available to create a session
|
||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
||||
"""
|
||||
return dependencies_available() and Session is not None
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
config = get_alembic_config()
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(config)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
logging.warning("No target revision found.")
|
||||
elif current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
backup_path = db_path + ".bkp"
|
||||
if db_exists:
|
||||
shutil.copy(db_path, backup_path)
|
||||
else:
|
||||
backup_path = None
|
||||
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
os.remove(backup_path)
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def create_session():
|
||||
return Session()
|
||||
14
app/database/models.py
Normal file
14
app/database/models.py
Normal file
@ -0,0 +1,14 @@
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
|
||||
# TODO: Define models here
|
||||
@ -10,46 +10,70 @@ import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Optional
|
||||
from typing import Dict, TypedDict, Optional
|
||||
from aiohttp import web
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from utils.install_util import get_missing_requirements_message, requirements_path
|
||||
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
import app.logger
|
||||
|
||||
# The path to the requirements.txt file
|
||||
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||
|
||||
|
||||
def frontend_install_warning_message():
|
||||
"""The warning message to display when the frontend version is not up to date."""
|
||||
|
||||
extra = ""
|
||||
if sys.flags.no_user_site:
|
||||
extra = "-s "
|
||||
return f"""
|
||||
Please install the updated requirements.txt file by running:
|
||||
{sys.executable} {extra}-m pip install -r {req_path}
|
||||
{get_missing_requirements_message()}
|
||||
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
|
||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
|
||||
""".strip()
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
def is_valid_version(version: str) -> bool:
|
||||
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
|
||||
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
|
||||
return bool(re.match(pattern, version))
|
||||
|
||||
def get_installed_frontend_version():
|
||||
"""Get the currently installed frontend package version."""
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
return frontend_version_str
|
||||
|
||||
|
||||
def get_required_frontend_version():
|
||||
"""Get the required frontend version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-frontend-package=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-frontend-package not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required frontend version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def check_frontend_version():
|
||||
"""Check if the frontend version is up to date."""
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
return tuple(map(int, version.split(".")))
|
||||
|
||||
try:
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
frontend_version_str = get_installed_frontend_version()
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
with open(req_path, "r", encoding="utf-8") as f:
|
||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||
required_frontend_str = get_required_frontend_version()
|
||||
required_frontend = parse_version(required_frontend_str)
|
||||
if frontend_version < required_frontend:
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
@ -121,9 +145,22 @@ class FrontEndProvider:
|
||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||
return response.json()
|
||||
|
||||
@cached_property
|
||||
def latest_prerelease(self) -> Release:
|
||||
"""Get the latest pre-release version - even if it's older than the latest release"""
|
||||
release = [release for release in self.all_releases if release["prerelease"]]
|
||||
|
||||
if not release:
|
||||
raise ValueError("No pre-releases found")
|
||||
|
||||
# GitHub returns releases in reverse chronological order, so first is latest
|
||||
return release[0]
|
||||
|
||||
def get_release(self, version: str) -> Release:
|
||||
if version == "latest":
|
||||
return self.latest_release
|
||||
elif version == "prerelease":
|
||||
return self.latest_prerelease
|
||||
else:
|
||||
for release in self.all_releases:
|
||||
if release["tag_name"] in [version, f"v{version}"]:
|
||||
@ -164,6 +201,42 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
||||
class FrontendManager:
|
||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||
|
||||
@classmethod
|
||||
def get_required_frontend_version(cls) -> str:
|
||||
"""Get the required frontend package version."""
|
||||
return get_required_frontend_version()
|
||||
|
||||
@classmethod
|
||||
def get_installed_templates_version(cls) -> str:
|
||||
"""Get the currently installed workflow templates package version."""
|
||||
try:
|
||||
templates_version_str = version("comfyui-workflow-templates")
|
||||
return templates_version_str
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_required_templates_version(cls) -> str:
|
||||
"""Get the required workflow templates version from requirements.txt."""
|
||||
try:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("comfyui-workflow-templates=="):
|
||||
version_str = line.split("==")[-1]
|
||||
if not is_valid_version(version_str):
|
||||
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
|
||||
return None
|
||||
return version_str
|
||||
logging.error("comfyui-workflow-templates not found in requirements.txt")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logging.error("requirements.txt not found. Cannot determine required templates version.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error reading requirements.txt: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def default_frontend_path(cls) -> str:
|
||||
try:
|
||||
@ -185,7 +258,54 @@ comfyui-frontend-package is not installed.
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
def template_asset_map(cls) -> Optional[Dict[str, str]]:
|
||||
"""Return a mapping of template asset names to their absolute paths."""
|
||||
try:
|
||||
from comfyui_workflow_templates import (
|
||||
get_asset_path,
|
||||
iter_templates,
|
||||
)
|
||||
except ImportError:
|
||||
logging.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
comfyui-workflow-templates is not installed.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
template_entries = list(iter_templates())
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to enumerate workflow templates: {exc}")
|
||||
return None
|
||||
|
||||
asset_map: Dict[str, str] = {}
|
||||
try:
|
||||
for entry in template_entries:
|
||||
for asset in entry.assets:
|
||||
asset_map[asset.filename] = get_asset_path(
|
||||
entry.template_id, asset.filename
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to resolve template asset paths: {exc}")
|
||||
return None
|
||||
|
||||
if not asset_map:
|
||||
logging.error("No workflow template assets found. Did the packages install correctly?")
|
||||
return None
|
||||
|
||||
return asset_map
|
||||
|
||||
|
||||
@classmethod
|
||||
def legacy_templates_path(cls) -> Optional[str]:
|
||||
"""Return the legacy templates directory shipped inside the meta package."""
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@ -204,6 +324,7 @@ comfyui-workflow-templates is not installed.
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def embedded_docs_path(cls) -> str:
|
||||
@ -230,7 +351,7 @@ comfyui-workflow-templates is not installed.
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If the version string is invalid.
|
||||
"""
|
||||
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
|
||||
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
|
||||
match_result = re.match(VERSION_PATTERN, value)
|
||||
if match_result is None:
|
||||
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
|
||||
@ -320,3 +441,17 @@ comfyui-workflow-templates is not installed.
|
||||
logging.info("Falling back to the default frontend.")
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
@classmethod
|
||||
def template_asset_handler(cls):
|
||||
assets = cls.template_asset_map()
|
||||
if not assets:
|
||||
return None
|
||||
|
||||
async def serve_template(request: web.Request) -> web.StreamResponse:
|
||||
rel_path = request.match_info.get("path", "")
|
||||
target = assets.get(rel_path)
|
||||
if target is None:
|
||||
raise web.HTTPNotFound()
|
||||
return web.FileResponse(target)
|
||||
|
||||
return serve_template
|
||||
|
||||
@ -130,10 +130,21 @@ class ModelFileManager:
|
||||
|
||||
for file_name in filenames:
|
||||
try:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
except:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||
full_path = os.path.join(dirpath, file_name)
|
||||
relative_path = os.path.relpath(full_path, directory)
|
||||
|
||||
# Get file metadata
|
||||
file_info = {
|
||||
"name": relative_path,
|
||||
"pathIndex": pathIndex,
|
||||
"modified": os.path.getmtime(full_path), # Add modification time
|
||||
"created": os.path.getctime(full_path), # Add creation time
|
||||
"size": os.path.getsize(full_path) # Add file size
|
||||
}
|
||||
result.append(file_info)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
@ -144,7 +155,7 @@ class ModelFileManager:
|
||||
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
|
||||
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||
return result, dirs, time.perf_counter()
|
||||
|
||||
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
||||
dirname = os.path.dirname(filepath)
|
||||
|
||||
112
app/subgraph_manager.py
Normal file
112
app/subgraph_manager.py
Normal file
@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
from aiohttp import web
|
||||
import hashlib
|
||||
|
||||
|
||||
class Source:
|
||||
custom_node = "custom_node"
|
||||
|
||||
class SubgraphEntry(TypedDict):
|
||||
source: str
|
||||
"""
|
||||
Source of subgraph - custom_nodes vs templates.
|
||||
"""
|
||||
path: str
|
||||
"""
|
||||
Relative path of the subgraph file.
|
||||
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
|
||||
"""
|
||||
name: str
|
||||
"""
|
||||
Name of subgraph file.
|
||||
"""
|
||||
info: CustomNodeSubgraphEntryInfo
|
||||
"""
|
||||
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
||||
"""
|
||||
data: str
|
||||
|
||||
class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||
node_pack: str
|
||||
"""Node pack name."""
|
||||
|
||||
class SubgraphManager:
|
||||
def __init__(self):
|
||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
|
||||
if entry is None:
|
||||
return None
|
||||
entry = entry.copy()
|
||||
entry.pop('path', None)
|
||||
if remove_data:
|
||||
entry.pop('data', None)
|
||||
return entry
|
||||
|
||||
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
|
||||
entries = entries.copy()
|
||||
for key in list(entries.keys()):
|
||||
entries[key] = await self.sanitize_entry(entries[key], remove_data)
|
||||
return entries
|
||||
|
||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||
# if not forced to reload and cached, return cache
|
||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||
return self.cached_custom_node_subgraphs
|
||||
# Load subgraphs from custom nodes
|
||||
subfolder = "subgraphs"
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
for file in matched_files:
|
||||
# replace backslashes with forward slashes
|
||||
file = file.replace('\\', '/')
|
||||
info: CustomNodeSubgraphEntryInfo = {
|
||||
"node_pack": "custom_nodes." + file.split('/')[-3]
|
||||
}
|
||||
source = Source.custom_node
|
||||
# hash source + path to make sure id will be as unique as possible, but
|
||||
# reproducible across backend reloads
|
||||
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": Source.custom_node,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": info,
|
||||
}
|
||||
subgraphs_dict[id] = entry
|
||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
||||
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
||||
entry: SubgraphEntry = subgraphs.get(id, None)
|
||||
if entry is not None and entry.get('data', None) is None:
|
||||
await self.load_entry_data(entry)
|
||||
return entry
|
||||
|
||||
def add_routes(self, routes, loadedModules):
|
||||
@routes.get("/global_subgraphs")
|
||||
async def get_global_subgraphs(request):
|
||||
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
||||
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
||||
# that's the reasoning for the current implementation
|
||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||
|
||||
@routes.get("/global_subgraphs/{id}")
|
||||
async def get_global_subgraph(request):
|
||||
id = request.match_info.get("id", None)
|
||||
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
||||
return web.json_response(await self.sanitize_entry(subgraph))
|
||||
@ -20,13 +20,15 @@ class FileInfo(TypedDict):
|
||||
path: str
|
||||
size: int
|
||||
modified: int
|
||||
created: int
|
||||
|
||||
|
||||
def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path)
|
||||
"modified": os.path.getmtime(path),
|
||||
"created": os.path.getctime(path)
|
||||
}
|
||||
|
||||
|
||||
@ -57,6 +59,9 @@ class UserManager():
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
user = request.headers["comfy-user"]
|
||||
# Block System Users (use same error message to prevent probing)
|
||||
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise KeyError("Unknown user: " + user)
|
||||
|
||||
if user not in self.users:
|
||||
raise KeyError("Unknown user: " + user)
|
||||
@ -64,15 +69,16 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
root_dir = folder_paths.get_user_directory()
|
||||
else:
|
||||
raise KeyError("Unknown filepath type:" + type)
|
||||
|
||||
user = self.get_request_user_id(request)
|
||||
path = user_root = os.path.abspath(os.path.join(root_dir, user))
|
||||
user_root = folder_paths.get_public_user_directory(user)
|
||||
if user_root is None:
|
||||
return None
|
||||
path = user_root
|
||||
|
||||
# prevent leaving /{type}
|
||||
if os.path.commonpath((root_dir, user_root)) != root_dir:
|
||||
@ -99,7 +105,11 @@ class UserManager():
|
||||
name = name.strip()
|
||||
if not name:
|
||||
raise ValueError("username not provided")
|
||||
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
|
||||
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = user_id + "_" + str(uuid.uuid4())
|
||||
|
||||
self.users[user_id] = name
|
||||
@ -130,7 +140,10 @@ class UserManager():
|
||||
if username in self.users.values():
|
||||
return web.json_response({"error": "Duplicate username."}, status=400)
|
||||
|
||||
user_id = self.add_user(username)
|
||||
try:
|
||||
user_id = self.add_user(username)
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
return web.json_response(user_id)
|
||||
|
||||
@routes.get("/userdata")
|
||||
@ -361,10 +374,17 @@ class UserManager():
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
status=400,
|
||||
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
|
||||
)
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
@ -415,7 +435,7 @@ class UserManager():
|
||||
return source
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(source, str):
|
||||
if not isinstance(dest, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
|
||||
91
comfy/audio_encoders/audio_encoders.py
Normal file
91
comfy/audio_encoders/audio_encoders.py
Normal file
@ -0,0 +1,91 @@
|
||||
from .wav2vec2 import Wav2Vec2Model
|
||||
from .whisper import WhisperLargeV3
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
import logging
|
||||
import torchaudio
|
||||
|
||||
|
||||
class AudioEncoderModel():
|
||||
def __init__(self, config):
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
model_type = config.pop("model_type")
|
||||
model_config = dict(config)
|
||||
model_config.update({
|
||||
"dtype": self.dtype,
|
||||
"device": offload_device,
|
||||
"operations": comfy.ops.manual_cast
|
||||
})
|
||||
|
||||
if model_type == "wav2vec2":
|
||||
self.model = Wav2Vec2Model(**model_config)
|
||||
elif model_type == "whisper3":
|
||||
self.model = WhisperLargeV3(**model_config)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_audio(self, audio, sample_rate):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
|
||||
out, all_layers = self.model(audio.to(self.load_device))
|
||||
outputs = {}
|
||||
outputs["encoded_audio"] = out
|
||||
outputs["encoded_audio_all_layers"] = all_layers
|
||||
outputs["audio_samples"] = audio.shape[2]
|
||||
return outputs
|
||||
|
||||
|
||||
def load_audio_encoder_from_sd(sd, prefix=""):
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
||||
if "encoder.layer_norm.bias" in sd: #wav2vec2
|
||||
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
|
||||
if embed_dim == 1024:# large
|
||||
config = {
|
||||
"model_type": "wav2vec2",
|
||||
"embed_dim": 1024,
|
||||
"num_heads": 16,
|
||||
"num_layers": 24,
|
||||
"conv_norm": True,
|
||||
"conv_bias": True,
|
||||
"do_normalize": True,
|
||||
"do_stable_layer_norm": True
|
||||
}
|
||||
elif embed_dim == 768: # base
|
||||
config = {
|
||||
"model_type": "wav2vec2",
|
||||
"embed_dim": 768,
|
||||
"num_heads": 12,
|
||||
"num_layers": 12,
|
||||
"conv_norm": False,
|
||||
"conv_bias": False,
|
||||
"do_normalize": False, # chinese-wav2vec2-base has this False
|
||||
"do_stable_layer_norm": False
|
||||
}
|
||||
else:
|
||||
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
|
||||
elif "model.encoder.embed_positions.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
|
||||
config = {
|
||||
"model_type": "whisper3",
|
||||
}
|
||||
else:
|
||||
raise RuntimeError("ERROR: audio encoder not supported.")
|
||||
|
||||
audio_encoder = AudioEncoderModel(config)
|
||||
m, u = audio_encoder.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing audio encoder: {}".format(m))
|
||||
if len(u) > 0:
|
||||
logging.warning("unexpected audio encoder: {}".format(u))
|
||||
|
||||
return audio_encoder
|
||||
252
comfy/audio_encoders/wav2vec2.py
Normal file
252
comfy/audio_encoders/wav2vec2.py
Normal file
@ -0,0 +1,252 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
|
||||
|
||||
class LayerNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
|
||||
|
||||
class LayerGroupNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x))
|
||||
|
||||
class ConvNoNorm(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class ConvFeatureEncoder(nn.Module):
|
||||
def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if conv_norm:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
else:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
x = x.unsqueeze(1)
|
||||
|
||||
for conv in self.conv_layers:
|
||||
x = conv(x)
|
||||
|
||||
return x.transpose(1, 2)
|
||||
|
||||
|
||||
class FeatureProjection(nn.Module):
|
||||
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer_norm(x)
|
||||
x = self.projection(x)
|
||||
return x
|
||||
|
||||
|
||||
class PositionalConvEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=groups,
|
||||
)
|
||||
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
||||
self.activation = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, 2)
|
||||
x = self.conv(x)[:, :, :-1]
|
||||
x = self.activation(x)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
num_layers=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.pos_conv_embed(x)
|
||||
all_x = ()
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x, mask)
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
assert (mask is None) # TODO?
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
|
||||
out = optimized_attention_masked(q, k, v, self.num_heads)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
|
||||
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.intermediate_dense(x)
|
||||
x = torch.nn.functional.gelu(x)
|
||||
x = self.output_dense(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
residual = x
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
x = self.attention(x, mask=mask)
|
||||
x = residual + x
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
return self.final_layer_norm(x + self.feed_forward(x))
|
||||
else:
|
||||
return x + self.feed_forward(self.final_layer_norm(x))
|
||||
|
||||
|
||||
class Wav2Vec2Model(nn.Module):
|
||||
"""Complete Wav2Vec 2.0 model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=1024,
|
||||
final_dim=256,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
conv_norm=True,
|
||||
conv_bias=True,
|
||||
do_normalize=True,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
conv_dim = 512
|
||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
self.encoder = TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, x, mask_time_indices=None, return_dict=False):
|
||||
x = torch.mean(x, dim=1)
|
||||
|
||||
if self.do_normalize:
|
||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
||||
|
||||
features = self.feature_extractor(x)
|
||||
features = self.feature_projection(features)
|
||||
batch_size, seq_len, _ = features.shape
|
||||
|
||||
x, all_x = self.encoder(features)
|
||||
return x, all_x
|
||||
186
comfy/audio_encoders/whisper.py
Executable file
186
comfy/audio_encoders/whisper.py
Executable file
@ -0,0 +1,186 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from typing import Optional
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.ops
|
||||
|
||||
class WhisperFeatureExtractor(nn.Module):
|
||||
def __init__(self, n_mels=128, device=None):
|
||||
super().__init__()
|
||||
self.sample_rate = 16000
|
||||
self.n_fft = 400
|
||||
self.hop_length = 160
|
||||
self.n_mels = n_mels
|
||||
self.chunk_length = 30
|
||||
self.n_samples = 480000
|
||||
|
||||
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
).to(device)
|
||||
|
||||
def __call__(self, audio):
|
||||
audio = torch.mean(audio, dim=1)
|
||||
batch_size = audio.shape[0]
|
||||
processed_audio = []
|
||||
|
||||
for i in range(batch_size):
|
||||
aud = audio[i]
|
||||
if aud.shape[0] > self.n_samples:
|
||||
aud = aud[:self.n_samples]
|
||||
elif aud.shape[0] < self.n_samples:
|
||||
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
|
||||
processed_audio.append(aud)
|
||||
|
||||
audio = torch.stack(processed_audio)
|
||||
|
||||
mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
|
||||
|
||||
log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
|
||||
log_mel_spec = (log_mel_spec + 4.0) / 4.0
|
||||
|
||||
return log_mel_spec
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert d_model % n_heads == 0
|
||||
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
self.d_k = d_model // n_heads
|
||||
|
||||
self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
|
||||
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = query.shape
|
||||
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
|
||||
attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
|
||||
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
||||
|
||||
self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
|
||||
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
|
||||
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x = self.self_attn(x, x, x, attention_mask)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.fc1(x)
|
||||
x = F.gelu(x)
|
||||
x = self.fc2(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int = 128,
|
||||
n_ctx: int = 1500,
|
||||
n_state: int = 1280,
|
||||
n_head: int = 20,
|
||||
n_layer: int = 32,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
|
||||
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
|
||||
|
||||
self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(n_layer)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
|
||||
|
||||
all_x = ()
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x)
|
||||
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
|
||||
class WhisperLargeV3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int = 128,
|
||||
n_audio_ctx: int = 1500,
|
||||
n_audio_state: int = 1280,
|
||||
n_audio_head: int = 20,
|
||||
n_audio_layer: int = 32,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
|
||||
|
||||
self.encoder = AudioEncoder(
|
||||
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, audio):
|
||||
mel = self.feature_extractor(audio)
|
||||
x, all_x = self.encoder(mel)
|
||||
return x, all_x
|
||||
@ -413,7 +413,8 @@ class ControlNet(nn.Module):
|
||||
out_middle = []
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
if y is None:
|
||||
raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x
|
||||
|
||||
@ -49,7 +49,8 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||
@ -96,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
|
||||
Latent2RGB = "latent2rgb"
|
||||
TAESD = "taesd"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str):
|
||||
for member in cls:
|
||||
if member.value == value:
|
||||
return member
|
||||
return None
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
@ -104,6 +112,7 @@ cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -119,6 +128,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
|
||||
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
||||
|
||||
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
@ -129,7 +144,10 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
@ -140,10 +158,14 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp16Accumulation = "fp16_accumulation"
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
@ -151,13 +173,15 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
@ -203,6 +227,11 @@ parser.add_argument(
|
||||
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
||||
)
|
||||
|
||||
database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
|
||||
@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
all_intermediate = None
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
if intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
|
||||
x = l(x, mask, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
return x, intermediate
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
@ -97,7 +107,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
|
||||
if embeds is not None:
|
||||
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
||||
else:
|
||||
|
||||
@ -50,7 +50,13 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
|
||||
model_type = config.get("model_type", "clip_vision_model")
|
||||
model_class = IMAGE_ENCODERS.get(model_type)
|
||||
if model_type == "siglip_vision_model":
|
||||
self.return_all_hidden_states = True
|
||||
else:
|
||||
self.return_all_hidden_states = False
|
||||
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
@ -68,12 +74,18 @@ class ClipVisionModel():
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
if self.return_all_hidden_states:
|
||||
all_hs = out[1].to(comfy.model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = all_hs[:, -2]
|
||||
outputs["all_hidden_states"] = all_hs
|
||||
else:
|
||||
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
|
||||
|
||||
outputs["mm_projected"] = out[3]
|
||||
return outputs
|
||||
|
||||
@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||
|
||||
# Dinov2
|
||||
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@ -37,6 +37,8 @@ class IO(StrEnum):
|
||||
CONTROL_NET = "CONTROL_NET"
|
||||
VAE = "VAE"
|
||||
MODEL = "MODEL"
|
||||
LORA_MODEL = "LORA_MODEL"
|
||||
LOSS_MAP = "LOSS_MAP"
|
||||
CLIP_VISION = "CLIP_VISION"
|
||||
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
|
||||
STYLE_MODEL = "STYLE_MODEL"
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import math
|
||||
import comfy.utils
|
||||
import logging
|
||||
|
||||
|
||||
class CONDRegular:
|
||||
@ -10,12 +11,15 @@ class CONDRegular:
|
||||
def _copy_with(self, cond):
|
||||
return self.__class__(cond)
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond.shape != other.cond.shape:
|
||||
return False
|
||||
if self.cond.device != other.cond.device:
|
||||
logging.warning("WARNING: conds not on same device, skipping concat.")
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
@ -29,14 +33,14 @@ class CONDRegular:
|
||||
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
def process_cond(self, batch_size, area, **kwargs):
|
||||
data = self.cond
|
||||
if area is not None:
|
||||
dims = len(area) // 2
|
||||
for i in range(dims):
|
||||
data = data.narrow(i + 2, area[i + dims], area[i])
|
||||
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
|
||||
|
||||
|
||||
class CONDCrossAttn(CONDRegular):
|
||||
@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
if self.cond.device != other.cond.device:
|
||||
logging.warning("WARNING: conds not on same device: skipping concat.")
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
@ -92,10 +99,10 @@ class CONDList(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
out = []
|
||||
for c in self.cond:
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
|
||||
|
||||
return self._copy_with(out)
|
||||
|
||||
|
||||
629
comfy/context_windows.py
Normal file
629
comfy/context_windows.py
Normal file
@ -0,0 +1,629 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
import torch
|
||||
import numpy as np
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
|
||||
class ContextWindowABC(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get torch.Tensor applicable to current window.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
class ContextHandlerABC(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
window[idx] = full[idx]
|
||||
return window.to(device)
|
||||
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
full[idx] += to_add
|
||||
return full
|
||||
|
||||
def get_region_index(self, num_regions: int) -> int:
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
RESIZE_COND_ITEM = "resize_cond_item"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
@dataclass
|
||||
class ContextFuseMethod:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
self.context_overlap = context_overlap
|
||||
self.context_stride = context_stride
|
||||
self.closed_loop = closed_loop
|
||||
self.dim = dim
|
||||
self._step = 0
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||
if control.previous_controlnet is not None:
|
||||
self.prepare_control_objects(control.previous_controlnet, device)
|
||||
return control
|
||||
|
||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
|
||||
if cond_in is None:
|
||||
return None
|
||||
# reuse or resize cond items to match context requirements
|
||||
resized_cond = []
|
||||
# if multiple conds, split based on primary region
|
||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||
region = window.get_region_index(len(cond_in))
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
cond_in = [cond_in[region]]
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
resized_actual_cond = actual_cond.copy()
|
||||
# now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
|
||||
for key in actual_cond:
|
||||
try:
|
||||
cond_item = actual_cond[key]
|
||||
if isinstance(cond_item, torch.Tensor):
|
||||
# check that tensor is the expected length - x.size(0)
|
||||
if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
|
||||
# if so, it's subsetting time - tell controls the expected indeces so they can handle them
|
||||
actual_cond_item = window.get_tensor(cond_item)
|
||||
resized_actual_cond[key] = actual_cond_item.to(device)
|
||||
else:
|
||||
resized_actual_cond[key] = cond_item.to(device)
|
||||
# look for control
|
||||
elif key == "control":
|
||||
resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
|
||||
elif isinstance(cond_item, dict):
|
||||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
# Allow callbacks to handle custom conditioning items
|
||||
handled = False
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(
|
||||
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
|
||||
):
|
||||
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
audio_cond = cond_value.cond
|
||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
resized_actual_cond[key] = new_cond_item
|
||||
else:
|
||||
resized_actual_cond[key] = cond_item
|
||||
finally:
|
||||
del cond_item # just in case to prevent VRAM issues
|
||||
resized_cond.append(resized_actual_cond)
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
self._step = int(matches[0].item())
|
||||
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
else:
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
for result in results:
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
try:
|
||||
# finalize conds
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
# relative is already normalized, so return as is
|
||||
del counts_final
|
||||
return conds_final
|
||||
else:
|
||||
# normalize conds via division by context usage counts
|
||||
for i in range(len(conds_final)):
|
||||
conds_final[i] /= counts_final[i]
|
||||
del counts_final
|
||||
return conds_final
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
# update exposed params
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
# get subsections of x, timestep, conds
|
||||
sub_x = window.get_tensor(x_in, device)
|
||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
||||
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
|
||||
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
for pos, idx in enumerate(window.index_list):
|
||||
# bias is the influence of a specific index in relation to the whole context window
|
||||
bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
|
||||
bias = max(1e-2, bias)
|
||||
# take weighted average relative to total bias of current idx
|
||||
for i in range(len(sub_conds_out)):
|
||||
bias_total = biases_final[i][idx]
|
||||
prev_weight = (bias_total / (bias_total + bias))
|
||||
new_weight = (bias / (bias_total + bias))
|
||||
# account for dims of tensors
|
||||
idx_window = tuple([slice(None)] * self.dim + [idx])
|
||||
pos_window = tuple([slice(None)] * self.dim + [pos])
|
||||
# apply new values
|
||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
else:
|
||||
# add conds and counts based on weights of fuse method
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||
for i in range(len(sub_conds_out)):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
window.add_window(counts_final[i], weights_tensor)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
||||
model_options = kwargs.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is not None:
|
||||
noise_shape = list(noise_shape)
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, *args, **kwargs)
|
||||
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
|
||||
"ContextWindows_prepare_sampling",
|
||||
_prepare_sampling_wrapper
|
||||
)
|
||||
|
||||
|
||||
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
|
||||
model_options = extra_args.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is None:
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
"ContextWindows_sampler_sample",
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
for _ in range(dim):
|
||||
weights_tensor = weights_tensor.unsqueeze(0)
|
||||
for _ in range(total_dims - dim - 1):
|
||||
weights_tensor = weights_tensor.unsqueeze(-1)
|
||||
return weights_tensor
|
||||
|
||||
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
|
||||
total_dims = len(x_in.shape)
|
||||
shape = []
|
||||
for _ in range(dim):
|
||||
shape.append(1)
|
||||
shape.append(x_in.shape[dim])
|
||||
for _ in range(total_dims - dim - 1):
|
||||
shape.append(1)
|
||||
return shape
|
||||
|
||||
class ContextSchedules:
|
||||
UNIFORM_LOOPED = "looped_uniform"
|
||||
UNIFORM_STANDARD = "standard_uniform"
|
||||
STATIC_STANDARD = "standard_static"
|
||||
BATCHED = "batched"
|
||||
|
||||
|
||||
# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
|
||||
def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames < handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
|
||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
||||
# obtain uniform windows as normal, looping and all
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
return windows
|
||||
|
||||
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
# unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
|
||||
# instead, they get shifted to the corresponding end of the frames.
|
||||
# in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
|
||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
||||
# first, obtain uniform windows as normal, looping and all
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (-handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
# now that windows are created, shift any windows that loop, and delete duplicate windows
|
||||
delete_idxs = []
|
||||
win_i = 0
|
||||
while win_i < len(windows):
|
||||
# if window is rolls over itself, need to shift it
|
||||
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
|
||||
if is_roll:
|
||||
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
|
||||
shift_window_to_end(windows[win_i], num_frames=num_frames)
|
||||
# check if next window (cyclical) is missing roll_val
|
||||
if roll_val not in windows[(win_i+1) % len(windows)]:
|
||||
# need to insert new window here - just insert window starting at roll_val
|
||||
windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
|
||||
# delete window if it's not unique
|
||||
for pre_i in range(0, win_i):
|
||||
if windows[win_i] == windows[pre_i]:
|
||||
delete_idxs.append(win_i)
|
||||
break
|
||||
win_i += 1
|
||||
|
||||
# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
|
||||
delete_idxs.reverse()
|
||||
for i in delete_idxs:
|
||||
windows.pop(i)
|
||||
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
# always return the same set of windows
|
||||
delta = handler.context_length - handler.context_overlap
|
||||
for start_idx in range(0, num_frames, delta):
|
||||
# if past the end of frames, move start_idx back to allow same context_length
|
||||
ending = start_idx + handler.context_length
|
||||
if ending >= num_frames:
|
||||
final_delta = ending - num_frames
|
||||
final_start_idx = start_idx - final_delta
|
||||
windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
|
||||
break
|
||||
windows.append(list(range(start_idx, start_idx + handler.context_length)))
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
# always return the same set of windows;
|
||||
# no overlap, just cut up based on context_length;
|
||||
# last window size will be different if num_frames % opts.context_length != 0
|
||||
for start_idx in range(0, num_frames, handler.context_length):
|
||||
windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_default(num_frames: int, handler: IndexListContextHandler):
|
||||
return [list(range(num_frames))]
|
||||
|
||||
|
||||
CONTEXT_MAPPING = {
|
||||
ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
|
||||
ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
|
||||
ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
|
||||
ContextSchedules.BATCHED: create_windows_batched,
|
||||
}
|
||||
|
||||
|
||||
def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
func = CONTEXT_MAPPING.get(context_schedule, None)
|
||||
if func is None:
|
||||
raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
|
||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
||||
|
||||
|
||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
# weight is the same for all
|
||||
return [1.0] * length
|
||||
|
||||
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
# weight is based on the distance away from the edge of the context window;
|
||||
# based on weighted average concept in FreeNoise paper
|
||||
if length % 2 == 0:
|
||||
max_weight = length // 2
|
||||
weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
|
||||
else:
|
||||
max_weight = (length + 1) // 2
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
|
||||
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
||||
# only expected overlap is given different weights
|
||||
weights_torch = torch.ones((length))
|
||||
# blend left-side on all except first window
|
||||
if min(idxs) > 0:
|
||||
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
|
||||
weights_torch[:handler.context_overlap] = ramp_up
|
||||
# blend right-side on all except last window
|
||||
if max(idxs) < full_length-1:
|
||||
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
|
||||
weights_torch[-handler.context_overlap:] = ramp_down
|
||||
return weights_torch
|
||||
|
||||
class ContextFuseMethods:
|
||||
FLAT = "flat"
|
||||
PYRAMID = "pyramid"
|
||||
RELATIVE = "relative"
|
||||
OVERLAP_LINEAR = "overlap-linear"
|
||||
|
||||
LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
|
||||
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
|
||||
|
||||
|
||||
FUSE_MAPPING = {
|
||||
ContextFuseMethods.FLAT: create_weights_flat,
|
||||
ContextFuseMethods.PYRAMID: create_weights_pyramid,
|
||||
ContextFuseMethods.RELATIVE: create_weights_pyramid,
|
||||
ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
|
||||
}
|
||||
|
||||
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
|
||||
func = FUSE_MAPPING.get(fuse_method, None)
|
||||
if func is None:
|
||||
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
|
||||
return ContextFuseMethod(fuse_method, func)
|
||||
|
||||
# Returns fraction that has denominator that is a power of 2
|
||||
def ordered_halving(val):
|
||||
# get binary value, padded with 0s for 64 bits
|
||||
bin_str = f"{val:064b}"
|
||||
# flip binary value, padding included
|
||||
bin_flip = bin_str[::-1]
|
||||
# convert binary to int
|
||||
as_int = int(bin_flip, 2)
|
||||
# divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
|
||||
# or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's)
|
||||
return as_int / (1 << 64)
|
||||
|
||||
|
||||
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
|
||||
all_indexes = list(range(num_frames))
|
||||
for w in windows:
|
||||
for val in w:
|
||||
try:
|
||||
all_indexes.remove(val)
|
||||
except ValueError:
|
||||
pass
|
||||
return all_indexes
|
||||
|
||||
|
||||
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
|
||||
prev_val = -1
|
||||
for i, val in enumerate(window):
|
||||
val = val % num_frames
|
||||
if val < prev_val:
|
||||
return True, i
|
||||
prev_val = val
|
||||
return False, -1
|
||||
|
||||
|
||||
def shift_window_to_start(window: list[int], num_frames: int):
|
||||
start_val = window[0]
|
||||
for i in range(len(window)):
|
||||
# 1) subtract each element by start_val to move vals relative to the start of all frames
|
||||
# 2) add num_frames and take modulus to get adjusted vals
|
||||
window[i] = ((window[i] - start_val) + num_frames) % num_frames
|
||||
|
||||
|
||||
def shift_window_to_end(window: list[int], num_frames: int):
|
||||
# 1) shift window to start
|
||||
shift_window_to_start(window, num_frames)
|
||||
end_val = window[-1]
|
||||
end_delta = num_frames - end_val - 1
|
||||
for i in range(len(window)):
|
||||
# 2) add end_delta to each val to slide windows to end
|
||||
window[i] = window[i] + end_delta
|
||||
|
||||
|
||||
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
||||
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
||||
logging.info("Context windows: Applying FreeNoise")
|
||||
generator = torch.Generator(device='cpu').manual_seed(seed)
|
||||
latent_video_length = noise.shape[dim]
|
||||
delta = context_length - context_overlap
|
||||
|
||||
for start_idx in range(0, latent_video_length - context_length, delta):
|
||||
place_idx = start_idx + context_length
|
||||
|
||||
actual_delta = min(delta, latent_video_length - place_idx)
|
||||
if actual_delta <= 0:
|
||||
break
|
||||
|
||||
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
||||
|
||||
source_slice = [slice(None)] * noise.ndim
|
||||
source_slice[dim] = list_idx
|
||||
target_slice = [slice(None)] * noise.ndim
|
||||
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
||||
|
||||
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
||||
|
||||
return noise
|
||||
@ -28,6 +28,7 @@ import comfy.model_detection
|
||||
import comfy.model_patcher
|
||||
import comfy.ops
|
||||
import comfy.latent_formats
|
||||
import comfy.model_base
|
||||
|
||||
import comfy.cldm.cldm
|
||||
import comfy.t2i_adapter.adapter
|
||||
@ -35,6 +36,7 @@ import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@ -43,7 +45,6 @@ if TYPE_CHECKING:
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
#print(current_batch_size, target_batch_size)
|
||||
if current_batch_size == 1:
|
||||
return tensor
|
||||
|
||||
@ -236,11 +237,11 @@ class ControlNet(ControlBase):
|
||||
self.cond_hint = None
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
compression_ratio *= self.vae.spacial_compression_encode()
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
@ -252,7 +253,10 @@ class ControlNet(ControlBase):
|
||||
to_concat = []
|
||||
for c in self.extra_concat_orig:
|
||||
c = c.to(self.cond_hint.device)
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
|
||||
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
|
||||
if c.ndim < self.cond_hint.ndim:
|
||||
c = c.unsqueeze(2)
|
||||
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
|
||||
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
|
||||
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
|
||||
|
||||
@ -265,12 +269,12 @@ class ControlNet(ControlBase):
|
||||
for c in self.extra_conds:
|
||||
temp = cond.get(c, None)
|
||||
if temp is not None:
|
||||
extra[c] = temp.to(dtype)
|
||||
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
|
||||
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype=None)
|
||||
|
||||
def copy(self):
|
||||
@ -306,11 +310,13 @@ class ControlLoraOps:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(
|
||||
@ -346,12 +352,13 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||
@ -390,8 +397,9 @@ class ControlLora(ControlNet):
|
||||
pass
|
||||
|
||||
for k in self.control_weights:
|
||||
if k not in {"lora_controlnet"}:
|
||||
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
||||
if (k not in {"lora_controlnet"}):
|
||||
if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k):
|
||||
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
||||
|
||||
def copy(self):
|
||||
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
||||
@ -581,6 +589,22 @@ def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
|
||||
|
||||
extra_condition_channels = 0
|
||||
concat_mask = False
|
||||
if control_latent_channels == 68: #inpaint controlnet
|
||||
extra_condition_channels = control_latent_channels - 64
|
||||
concat_mask = True
|
||||
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
extra_conds = []
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@ -654,8 +678,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
|
||||
@ -1,55 +1,10 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
from .ldm.modules.attention import CrossAttention, FeedForward
|
||||
import comfy.ops
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = ops.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * torch.nn.functional.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
ops.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
ops.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class GatedCrossAttentionDense(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
|
||||
@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
class Dinov2MLP(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
mlp_ratio = 4
|
||||
hidden_features = int(hidden_size * mlp_ratio)
|
||||
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.fc1(hidden_state)
|
||||
hidden_state = torch.nn.functional.gelu(hidden_state)
|
||||
hidden_state = self.fc2(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
else:
|
||||
self.mlp = Dinov2MLP(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
for i, layer in enumerate(self.layer):
|
||||
x = layer(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
|
||||
22
comfy/image_encoders/dino2_large.json
Normal file
22
comfy/image_encoders/dino2_large.json
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"hidden_size": 1024,
|
||||
"use_mask_token": true,
|
||||
"patch_size": 14,
|
||||
"image_size": 518,
|
||||
"num_channels": 3,
|
||||
"num_attention_heads": 16,
|
||||
"initializer_range": 0.02,
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"mlp_ratio": 4,
|
||||
"model_type": "dinov2",
|
||||
"num_hidden_layers": 24,
|
||||
"layer_norm_eps": 1e-6,
|
||||
"qkv_bias": true,
|
||||
"use_swiglu_ffn": false,
|
||||
"layerscale_value": 1.0,
|
||||
"drop_path_rate": 0.0,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
121
comfy/k_diffusion/sa_solver.py
Normal file
121
comfy/k_diffusion/sa_solver.py
Normal file
@ -0,0 +1,121 @@
|
||||
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
|
||||
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
|
||||
# Codebase ref: https://github.com/scxue/SA-Solver
|
||||
|
||||
import math
|
||||
from typing import Union, Callable
|
||||
import torch
|
||||
|
||||
|
||||
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
|
||||
"""Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
|
||||
|
||||
Integral of exp((1 + tau^2) * x) * x^p dx
|
||||
= product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
|
||||
with base case p=0 where integral equals product_terms[0].
|
||||
|
||||
where
|
||||
product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
|
||||
|
||||
Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
|
||||
Return coefficients used by the SA-Solver in data prediction mode.
|
||||
|
||||
Args:
|
||||
s: Start time s.
|
||||
t: End time t.
|
||||
solver_order: Current order of the solver.
|
||||
tau_t: Stochastic strength parameter in the SDE.
|
||||
|
||||
Returns:
|
||||
Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order−1, shape (solver_order,).
|
||||
"""
|
||||
tau_mul = 1 + tau_t ** 2
|
||||
h = t - s
|
||||
p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
|
||||
|
||||
# product_terms after factoring out exp((1 + tau^2) * t)
|
||||
# Includes (1 + tau^2) factor from outside the integral
|
||||
product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
|
||||
|
||||
# Lower triangular recursive coefficient matrix
|
||||
# Accumulates recursive coefficients based on p / (1 + tau^2)
|
||||
recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
|
||||
log_factorial = (p + 1).lgamma()
|
||||
recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
|
||||
if tau_t > 0:
|
||||
recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
|
||||
signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
|
||||
recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
|
||||
|
||||
return recursive_coeff_mat @ product_terms_factored
|
||||
|
||||
|
||||
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
|
||||
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
|
||||
tau_mul = 1 + tau_t ** 2
|
||||
h = lambda_t - lambda_s
|
||||
alpha_t = sigma_next * lambda_t.exp()
|
||||
if is_corrector_step:
|
||||
# Simplified 1-step (order-2) corrector
|
||||
b_1 = alpha_t * (0.5 * tau_mul * h)
|
||||
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
|
||||
else:
|
||||
# Simplified 2-step predictor
|
||||
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
|
||||
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
|
||||
return torch.stack([b_2, b_1])
|
||||
|
||||
|
||||
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
|
||||
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
|
||||
|
||||
The solver order corresponds to the number of input lambdas (half-logSNR points).
|
||||
|
||||
Args:
|
||||
sigma_next: Sigma at end time t.
|
||||
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
|
||||
lambda_s: Lambda at start time s.
|
||||
lambda_t: Lambda at end time t.
|
||||
tau_t: Stochastic strength parameter in the SDE.
|
||||
simple_order_2: Whether to enable the simple order-2 scheme.
|
||||
is_corrector_step: Flag for corrector step in simple order-2 mode.
|
||||
|
||||
Returns:
|
||||
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
|
||||
"""
|
||||
num_timesteps = curr_lambdas.shape[0]
|
||||
|
||||
if simple_order_2 and num_timesteps == 2:
|
||||
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
|
||||
|
||||
# Compute coefficients by solving a linear system from Lagrange basis interpolation
|
||||
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
|
||||
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
|
||||
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
|
||||
|
||||
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
|
||||
# = sigma_t * exp(lambda_t) = alpha_t
|
||||
# exp((1 + tau^2) * lambda_t) is extracted from the integral
|
||||
alpha_t = sigma_next * lambda_t.exp()
|
||||
return alpha_t * lagrange_integrals
|
||||
|
||||
|
||||
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
|
||||
"""Return a function that controls the stochasticity of SA-Solver.
|
||||
|
||||
When eta = 0, SA-Solver runs as ODE. The official approach uses
|
||||
time t to determine the SDE interval, while here we use sigma instead.
|
||||
|
||||
See:
|
||||
https://github.com/scxue/SA-Solver/blob/main/README.md
|
||||
"""
|
||||
|
||||
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
|
||||
if eta <= 0:
|
||||
return 0.0 # ODE
|
||||
|
||||
if isinstance(sigma, torch.Tensor):
|
||||
sigma = sigma.item()
|
||||
return eta if start_sigma >= sigma >= end_sigma else 0.0
|
||||
|
||||
return tau_func
|
||||
@ -1,4 +1,5 @@
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
from scipy import integrate
|
||||
import torch
|
||||
@ -8,6 +9,7 @@ from tqdm.auto import trange, tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
@ -84,24 +86,24 @@ class BatchedBrownianTree:
|
||||
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
||||
|
||||
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
||||
self.cpu_tree = True
|
||||
if "cpu" in kwargs:
|
||||
self.cpu_tree = kwargs.pop("cpu")
|
||||
self.cpu_tree = kwargs.pop("cpu", True)
|
||||
t0, t1, self.sign = self.sort(t0, t1)
|
||||
w0 = kwargs.get('w0', torch.zeros_like(x))
|
||||
w0 = kwargs.pop('w0', None)
|
||||
if w0 is None:
|
||||
w0 = torch.zeros_like(x)
|
||||
self.batched = False
|
||||
if seed is None:
|
||||
seed = torch.randint(0, 2 ** 63 - 1, []).item()
|
||||
self.batched = True
|
||||
try:
|
||||
assert len(seed) == x.shape[0]
|
||||
seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
|
||||
elif isinstance(seed, (tuple, list)):
|
||||
if len(seed) != x.shape[0]:
|
||||
raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
|
||||
self.batched = True
|
||||
w0 = w0[0]
|
||||
except TypeError:
|
||||
seed = [seed]
|
||||
self.batched = False
|
||||
if self.cpu_tree:
|
||||
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
|
||||
else:
|
||||
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
||||
seed = (seed,)
|
||||
if self.cpu_tree:
|
||||
t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
|
||||
self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
|
||||
|
||||
@staticmethod
|
||||
def sort(a, b):
|
||||
@ -109,11 +111,10 @@ class BatchedBrownianTree:
|
||||
|
||||
def __call__(self, t0, t1):
|
||||
t0, t1, sign = self.sort(t0, t1)
|
||||
device, dtype = t0.device, t0.dtype
|
||||
if self.cpu_tree:
|
||||
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
|
||||
else:
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
||||
|
||||
t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
|
||||
return w if self.batched else w[0]
|
||||
|
||||
|
||||
@ -142,6 +143,43 @@ class BrownianTreeNoiseSampler:
|
||||
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
|
||||
def sigma_to_half_log_snr(sigma, model_sampling):
|
||||
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
|
||||
if isinstance(model_sampling, comfy.model_sampling.CONST):
|
||||
# log((1 - t) / t) = log((1 - sigma) / sigma)
|
||||
return sigma.logit().neg()
|
||||
return sigma.log().neg()
|
||||
|
||||
|
||||
def half_log_snr_to_sigma(half_log_snr, model_sampling):
|
||||
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
|
||||
if isinstance(model_sampling, comfy.model_sampling.CONST):
|
||||
# 1 / (1 + exp(half_log_snr))
|
||||
return half_log_snr.neg().sigmoid()
|
||||
return half_log_snr.neg().exp()
|
||||
|
||||
|
||||
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
|
||||
"""Adjust the first sigma to avoid invalid logSNR."""
|
||||
if len(sigmas) <= 1:
|
||||
return sigmas
|
||||
if isinstance(model_sampling, comfy.model_sampling.CONST):
|
||||
if sigmas[0] >= 1:
|
||||
sigmas = sigmas.clone()
|
||||
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
|
||||
return sigmas
|
||||
|
||||
|
||||
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_1(h) in exponential integrator methods."""
|
||||
return torch.expm1(h)
|
||||
|
||||
|
||||
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the result of h*phi_2(h) in exponential integrator methods."""
|
||||
return (torch.expm1(h) - h) / h
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
@ -384,9 +422,13 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
||||
|
||||
|
||||
@ -682,6 +724,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
|
||||
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
@ -693,38 +736,49 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
x = x + d * dt
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
s = t + h * r
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
alpha_s = sigmas[i] * lambda_s.exp()
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 1
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
|
||||
s_ = t_fn(sd)
|
||||
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
|
||||
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
|
||||
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
||||
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
|
||||
lambda_s_1_ = sd.log().neg()
|
||||
h_ = lambda_s_1_ - lambda_s
|
||||
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
|
||||
if eta > 0 and s_noise > 0:
|
||||
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
|
||||
t_next_ = t_fn(sd)
|
||||
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
|
||||
lambda_t_ = sd.log().neg()
|
||||
h_ = lambda_t_ - lambda_s
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
|
||||
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
|
||||
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
|
||||
return x
|
||||
|
||||
|
||||
@ -753,6 +807,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
"""DPM-Solver++(2M) SDE."""
|
||||
@ -768,9 +823,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
old_denoised = None
|
||||
h_last = None
|
||||
h = None
|
||||
h, h_last = None, None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@ -781,26 +839,34 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++(2M) SDE
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
eta_h = eta * h
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||
|
||||
if old_denoised is not None:
|
||||
r = h_last / h
|
||||
if solver_type == 'heun':
|
||||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
elif solver_type == 'midpoint':
|
||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
|
||||
if eta:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver++(3M) SDE."""
|
||||
@ -814,6 +880,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
denoised_1, denoised_2 = None, None
|
||||
h, h_1, h_2 = None, None, None
|
||||
|
||||
@ -825,13 +895,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
|
||||
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||
|
||||
if h_2 is not None:
|
||||
# DPM-Solver++(3M) SDE
|
||||
r0 = h_1 / h
|
||||
r1 = h_2 / h
|
||||
d1_0 = (denoised - denoised_1) / r0
|
||||
@ -840,20 +913,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
d2 = (d1_0 - d1_1) / (r0 + r1)
|
||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||
phi_3 = phi_2 / h_eta - 0.5
|
||||
x = x + phi_2 * d1 - phi_3 * d2
|
||||
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
|
||||
elif h_1 is not None:
|
||||
# DPM-Solver++(2M) SDE
|
||||
r = h_1 / h
|
||||
d = (denoised - denoised_1) / r
|
||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||
x = x + phi_2 * d
|
||||
x = x + (alpha_t * phi_2) * d
|
||||
|
||||
if eta:
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||
|
||||
denoised_1, denoised_2 = denoised, denoised_1
|
||||
h_1, h_2 = h, h_1
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
@ -863,6 +938,17 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
@ -872,6 +958,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
@ -1009,7 +1096,9 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if order == 1: # First Euler step.
|
||||
if t_next == 0: # Denoising step
|
||||
x_next = denoised
|
||||
elif order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
|
||||
@ -1027,6 +1116,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
|
||||
@ -1050,7 +1140,9 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if order == 1: # First Euler step.
|
||||
if t_next == 0: # Denoising step
|
||||
x_next = denoised
|
||||
elif order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
h_n = (t_next - t_cur)
|
||||
@ -1090,6 +1182,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
@torch.no_grad()
|
||||
@ -1140,39 +1233,22 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
return x_next
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
temp = [0]
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, temp[0])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
# Euler method
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
"""Ancestral sampling with Euler method steps (CFG++)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
|
||||
uncond_denoised = None
|
||||
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
@ -1181,15 +1257,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], temp[0])
|
||||
# Euler method
|
||||
x = denoised + d * sigma_down
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
|
||||
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
|
||||
|
||||
# DDIM stochastic sampling
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
|
||||
sigma_down = alpha_t * sigma_down
|
||||
|
||||
# Euler method
|
||||
x = alpha_t * denoised + sigma_down * d
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""Euler method steps (CFG++)."""
|
||||
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
@ -1346,6 +1440,7 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
|
||||
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
|
||||
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
||||
@ -1372,31 +1467,32 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
if i == 0:
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# Euler method
|
||||
if cfg_pp:
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
else:
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Gradient estimation
|
||||
if cfg_pp:
|
||||
|
||||
if i >= 1:
|
||||
# Gradient estimation
|
||||
d_bar = (ge_gamma - 1) * (d - old_d)
|
||||
x = denoised + d * sigmas[i + 1] + d_bar * dt
|
||||
else:
|
||||
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
|
||||
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
@ -1404,12 +1500,18 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
def default_noise_scaler(sigma):
|
||||
return sigma * ((sigma ** 0.3).exp() + 10.0)
|
||||
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
|
||||
def default_er_sde_noise_scaler(x):
|
||||
return x * ((x ** 0.3).exp() + 10.0)
|
||||
|
||||
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
|
||||
num_integration_points = 200.0
|
||||
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
|
||||
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
|
||||
|
||||
old_denoised = None
|
||||
old_denoised_d = None
|
||||
|
||||
@ -1420,129 +1522,274 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
stage_used = min(max_stage, i + 1)
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
elif stage_used == 1:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
else:
|
||||
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
|
||||
x = r * x + (1 - r) * denoised
|
||||
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
|
||||
alpha_s = sigmas[i] / er_lambda_s
|
||||
alpha_t = sigmas[i + 1] / er_lambda_t
|
||||
r_alpha = alpha_t / alpha_s
|
||||
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
|
||||
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
sigma_step_size = -dt / num_integration_points
|
||||
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
|
||||
scaled_pos = noise_scaler(sigma_pos)
|
||||
# Stage 1 Euler
|
||||
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
|
||||
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * sigma_step_size
|
||||
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
|
||||
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
|
||||
if stage_used >= 2:
|
||||
dt = er_lambda_t - er_lambda_s
|
||||
lambda_step_size = -dt / num_integration_points
|
||||
lambda_pos = er_lambda_t + point_indice * lambda_step_size
|
||||
scaled_pos = noise_scaler(lambda_pos)
|
||||
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
|
||||
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
# Stage 2
|
||||
s = torch.sum(1 / scaled_pos) * lambda_step_size
|
||||
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
|
||||
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
if stage_used >= 3:
|
||||
# Stage 3
|
||||
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
|
||||
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
|
||||
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise > 0:
|
||||
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
'''
|
||||
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
if solver_type not in {"phi_1", "phi_2"}:
|
||||
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
fac = 1 / (2 * r)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s = t + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s = s.neg().exp()
|
||||
continue
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
if solver_type == "phi_1":
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
elif solver_type == "phi_2":
|
||||
b2 = ei_h_phi_2(-h_eta) / r
|
||||
b1 = ei_h_phi_1(-h_eta) - b2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
'''
|
||||
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s_1 = t + r_1 * h
|
||||
s_2 = t + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
|
||||
continue
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
|
||||
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
|
||||
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
# Step 2
|
||||
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
|
||||
if inject_noise:
|
||||
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
|
||||
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
# Step 2
|
||||
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
|
||||
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
|
||||
if inject_noise:
|
||||
segment_factor = (r_1 - r_2) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
|
||||
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
b3 = ei_h_phi_2(-h_eta) / r_2
|
||||
b1 = ei_h_phi_1(-h_eta) - b3
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
|
||||
if inject_noise:
|
||||
segment_factor = (r_2 - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
|
||||
if len(sigmas) <= 1:
|
||||
return x
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
|
||||
|
||||
if tau_func is None:
|
||||
# Use default interval for stochastic sampling
|
||||
start_sigma = model_sampling.percent_to_sigma(0.2)
|
||||
end_sigma = model_sampling.percent_to_sigma(0.8)
|
||||
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
|
||||
|
||||
max_used_order = max(predictor_order, corrector_order)
|
||||
x_pred = x # x: current state, x_pred: predicted next state
|
||||
|
||||
h = 0.0
|
||||
tau_t = 0.0
|
||||
noise = 0.0
|
||||
pred_list = []
|
||||
|
||||
# Lower order near the end to improve stability
|
||||
lower_order_to_end = sigmas[-1].item() == 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
# Evaluation
|
||||
denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
pred_list.append(denoised)
|
||||
pred_list = pred_list[-max_used_order:]
|
||||
|
||||
predictor_order_used = min(predictor_order, len(pred_list))
|
||||
if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
|
||||
corrector_order_used = 0
|
||||
else:
|
||||
corrector_order_used = min(corrector_order, len(pred_list))
|
||||
|
||||
if lower_order_to_end:
|
||||
predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
|
||||
corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
|
||||
|
||||
# Corrector
|
||||
if corrector_order_used == 0:
|
||||
# Update by the predicted state
|
||||
x = x_pred
|
||||
else:
|
||||
curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
|
||||
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
|
||||
sigmas[i],
|
||||
curr_lambdas,
|
||||
lambdas[i - 1],
|
||||
lambdas[i],
|
||||
tau_t,
|
||||
simple_order_2,
|
||||
is_corrector_step=True,
|
||||
)
|
||||
pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
|
||||
corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
|
||||
x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
|
||||
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
# The noise from the previous predictor step
|
||||
x = x + noise
|
||||
|
||||
if use_pece:
|
||||
# Evaluate the corrected state
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
pred_list[-1] = denoised
|
||||
|
||||
# Predictor
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
tau_t = tau_func(sigmas[i + 1])
|
||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
|
||||
sigmas[i + 1],
|
||||
curr_lambdas,
|
||||
lambdas[i],
|
||||
lambdas[i + 1],
|
||||
tau_t,
|
||||
simple_order_2,
|
||||
is_corrector_step=False,
|
||||
)
|
||||
pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
|
||||
pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
|
||||
h = lambdas[i + 1] - lambdas[i]
|
||||
x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
|
||||
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||
x_pred = x_pred + noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||
|
||||
@ -6,6 +6,7 @@ class LatentFormat:
|
||||
latent_dimensions = 2
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
@ -178,6 +179,54 @@ class Flux(SD3):
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
[0.0058, 0.0113, 0.0073],
|
||||
[0.0495, 0.0443, 0.0836],
|
||||
[-0.0099, 0.0096, 0.0644],
|
||||
[0.2144, 0.3009, 0.3652],
|
||||
[0.0166, -0.0039, -0.0054],
|
||||
[0.0157, 0.0103, -0.0160],
|
||||
[-0.0398, 0.0902, -0.0235],
|
||||
[-0.0052, 0.0095, 0.0109],
|
||||
[-0.3527, -0.2712, -0.1666],
|
||||
[-0.0301, -0.0356, -0.0180],
|
||||
[-0.0107, 0.0078, 0.0013],
|
||||
[0.0746, 0.0090, -0.0941],
|
||||
[0.0156, 0.0169, 0.0070],
|
||||
[-0.0034, -0.0040, -0.0114],
|
||||
[0.0032, 0.0181, 0.0080],
|
||||
[-0.0939, -0.0008, 0.0186],
|
||||
[0.0018, 0.0043, 0.0104],
|
||||
[0.0284, 0.0056, -0.0127],
|
||||
[-0.0024, -0.0022, -0.0030],
|
||||
[0.1207, -0.0026, 0.0065],
|
||||
[0.0128, 0.0101, 0.0142],
|
||||
[0.0137, -0.0072, -0.0007],
|
||||
[0.0095, 0.0092, -0.0059],
|
||||
[0.0000, -0.0077, -0.0049],
|
||||
[-0.0465, -0.0204, -0.0312],
|
||||
[0.0095, 0.0012, -0.0066],
|
||||
[0.0290, -0.0034, 0.0025],
|
||||
[0.0220, 0.0169, -0.0048],
|
||||
[-0.0332, -0.0457, -0.0468],
|
||||
[-0.0085, 0.0389, 0.0609],
|
||||
[-0.0076, 0.0003, -0.0043],
|
||||
[-0.0111, -0.0460, -0.0614],
|
||||
]
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
@ -382,6 +431,7 @@ class HunyuanVideo(LatentFormat):
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||
taesd_decoder_name = "taehv"
|
||||
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
@ -445,7 +495,7 @@ class Wan21(LatentFormat):
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
|
||||
self.taesd_decoder_name = None #TODO
|
||||
self.taesd_decoder_name = "lighttaew2_1"
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
@ -457,11 +507,232 @@ class Wan21(LatentFormat):
|
||||
latents_std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * latents_std / self.scale_factor + latents_mean
|
||||
|
||||
class Wan22(Wan21):
|
||||
latent_channels = 48
|
||||
latent_dimensions = 3
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.0119, 0.0103, 0.0046],
|
||||
[-0.1062, -0.0504, 0.0165],
|
||||
[ 0.0140, 0.0409, 0.0491],
|
||||
[-0.0813, -0.0677, 0.0607],
|
||||
[ 0.0656, 0.0851, 0.0808],
|
||||
[ 0.0264, 0.0463, 0.0912],
|
||||
[ 0.0295, 0.0326, 0.0590],
|
||||
[-0.0244, -0.0270, 0.0025],
|
||||
[ 0.0443, -0.0102, 0.0288],
|
||||
[-0.0465, -0.0090, -0.0205],
|
||||
[ 0.0359, 0.0236, 0.0082],
|
||||
[-0.0776, 0.0854, 0.1048],
|
||||
[ 0.0564, 0.0264, 0.0561],
|
||||
[ 0.0006, 0.0594, 0.0418],
|
||||
[-0.0319, -0.0542, -0.0637],
|
||||
[-0.0268, 0.0024, 0.0260],
|
||||
[ 0.0539, 0.0265, 0.0358],
|
||||
[-0.0359, -0.0312, -0.0287],
|
||||
[-0.0285, -0.1032, -0.1237],
|
||||
[ 0.1041, 0.0537, 0.0622],
|
||||
[-0.0086, -0.0374, -0.0051],
|
||||
[ 0.0390, 0.0670, 0.2863],
|
||||
[ 0.0069, 0.0144, 0.0082],
|
||||
[ 0.0006, -0.0167, 0.0079],
|
||||
[ 0.0313, -0.0574, -0.0232],
|
||||
[-0.1454, -0.0902, -0.0481],
|
||||
[ 0.0714, 0.0827, 0.0447],
|
||||
[-0.0304, -0.0574, -0.0196],
|
||||
[ 0.0401, 0.0384, 0.0204],
|
||||
[-0.0758, -0.0297, -0.0014],
|
||||
[ 0.0568, 0.1307, 0.1372],
|
||||
[-0.0055, -0.0310, -0.0380],
|
||||
[ 0.0239, -0.0305, 0.0325],
|
||||
[-0.0663, -0.0673, -0.0140],
|
||||
[-0.0416, -0.0047, -0.0023],
|
||||
[ 0.0166, 0.0112, -0.0093],
|
||||
[-0.0211, 0.0011, 0.0331],
|
||||
[ 0.1833, 0.1466, 0.2250],
|
||||
[-0.0368, 0.0370, 0.0295],
|
||||
[-0.3441, -0.3543, -0.2008],
|
||||
[-0.0479, -0.0489, -0.0420],
|
||||
[-0.0660, -0.0153, 0.0800],
|
||||
[-0.0101, 0.0068, 0.0156],
|
||||
[-0.0690, -0.0452, -0.0927],
|
||||
[-0.0145, 0.0041, 0.0015],
|
||||
[ 0.0421, 0.0451, 0.0373],
|
||||
[ 0.0504, -0.0483, -0.0356],
|
||||
[-0.0837, 0.0168, 0.0055]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.taesd_decoder_name = "lighttaew2_2"
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
||||
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
||||
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
||||
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
self.latents_std = torch.tensor([
|
||||
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
||||
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
||||
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
||||
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
||||
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
class HunyuanImage21(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 2
|
||||
scale_factor = 0.75289
|
||||
|
||||
latent_rgb_factors = [
|
||||
[-0.0154, -0.0397, -0.0521],
|
||||
[ 0.0005, 0.0093, 0.0006],
|
||||
[-0.0805, -0.0773, -0.0586],
|
||||
[-0.0494, -0.0487, -0.0498],
|
||||
[-0.0212, -0.0076, -0.0261],
|
||||
[-0.0179, -0.0417, -0.0505],
|
||||
[ 0.0158, 0.0310, 0.0239],
|
||||
[ 0.0409, 0.0516, 0.0201],
|
||||
[ 0.0350, 0.0553, 0.0036],
|
||||
[-0.0447, -0.0327, -0.0479],
|
||||
[-0.0038, -0.0221, -0.0365],
|
||||
[-0.0423, -0.0718, -0.0654],
|
||||
[ 0.0039, 0.0368, 0.0104],
|
||||
[ 0.0655, 0.0217, 0.0122],
|
||||
[ 0.0490, 0.1638, 0.2053],
|
||||
[ 0.0932, 0.0829, 0.0650],
|
||||
[-0.0186, -0.0209, -0.0135],
|
||||
[-0.0080, -0.0076, -0.0148],
|
||||
[-0.0284, -0.0201, 0.0011],
|
||||
[-0.0642, -0.0294, -0.0777],
|
||||
[-0.0035, 0.0076, -0.0140],
|
||||
[ 0.0519, 0.0731, 0.0887],
|
||||
[-0.0102, 0.0095, 0.0704],
|
||||
[ 0.0068, 0.0218, -0.0023],
|
||||
[-0.0726, -0.0486, -0.0519],
|
||||
[ 0.0260, 0.0295, 0.0263],
|
||||
[ 0.0250, 0.0333, 0.0341],
|
||||
[ 0.0168, -0.0120, -0.0174],
|
||||
[ 0.0226, 0.1037, 0.0114],
|
||||
[ 0.2577, 0.1906, 0.1604],
|
||||
[-0.0646, -0.0137, -0.0018],
|
||||
[-0.0112, 0.0309, 0.0358],
|
||||
[-0.0347, 0.0146, -0.0481],
|
||||
[ 0.0234, 0.0179, 0.0201],
|
||||
[ 0.0157, 0.0313, 0.0225],
|
||||
[ 0.0423, 0.0675, 0.0524],
|
||||
[-0.0031, 0.0027, -0.0255],
|
||||
[ 0.0447, 0.0555, 0.0330],
|
||||
[-0.0152, 0.0103, 0.0299],
|
||||
[-0.0755, -0.0489, -0.0635],
|
||||
[ 0.0853, 0.0788, 0.1017],
|
||||
[-0.0272, -0.0294, -0.0471],
|
||||
[ 0.0440, 0.0400, -0.0137],
|
||||
[ 0.0335, 0.0317, -0.0036],
|
||||
[-0.0344, -0.0621, -0.0984],
|
||||
[-0.0127, -0.0630, -0.0620],
|
||||
[-0.0648, 0.0360, 0.0924],
|
||||
[-0.0781, -0.0801, -0.0409],
|
||||
[ 0.0363, 0.0613, 0.0499],
|
||||
[ 0.0238, 0.0034, 0.0041],
|
||||
[-0.0135, 0.0258, 0.0310],
|
||||
[ 0.0614, 0.1086, 0.0589],
|
||||
[ 0.0428, 0.0350, 0.0205],
|
||||
[ 0.0153, 0.0173, -0.0018],
|
||||
[-0.0288, -0.0455, -0.0091],
|
||||
[ 0.0344, 0.0109, -0.0157],
|
||||
[-0.0205, -0.0247, -0.0187],
|
||||
[ 0.0487, 0.0126, 0.0064],
|
||||
[-0.0220, -0.0013, 0.0074],
|
||||
[-0.0203, -0.0094, -0.0048],
|
||||
[-0.0719, 0.0429, -0.0442],
|
||||
[ 0.1042, 0.0497, 0.0356],
|
||||
[-0.0659, -0.0578, -0.0280],
|
||||
[-0.0060, -0.0322, -0.0234]]
|
||||
|
||||
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
|
||||
|
||||
class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
def process_in(self, latent):
|
||||
out = latent * self.scale_factor
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
def process_out(self, latent):
|
||||
z = latent / self.scale_factor
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
return z
|
||||
|
||||
class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors = [
|
||||
[ 0.0568, -0.0521, -0.0131],
|
||||
[ 0.0014, 0.0735, 0.0326],
|
||||
[ 0.0186, 0.0531, -0.0138],
|
||||
[-0.0031, 0.0051, 0.0288],
|
||||
[ 0.0110, 0.0556, 0.0432],
|
||||
[-0.0041, -0.0023, -0.0485],
|
||||
[ 0.0530, 0.0413, 0.0253],
|
||||
[ 0.0283, 0.0251, 0.0339],
|
||||
[ 0.0277, -0.0372, -0.0093],
|
||||
[ 0.0393, 0.0944, 0.1131],
|
||||
[ 0.0020, 0.0251, 0.0037],
|
||||
[-0.0017, 0.0012, 0.0234],
|
||||
[ 0.0468, 0.0436, 0.0203],
|
||||
[ 0.0354, 0.0439, -0.0233],
|
||||
[ 0.0090, 0.0123, 0.0346],
|
||||
[ 0.0382, 0.0029, 0.0217],
|
||||
[ 0.0261, -0.0300, 0.0030],
|
||||
[-0.0088, -0.0220, -0.0283],
|
||||
[-0.0272, -0.0121, -0.0363],
|
||||
[-0.0664, -0.0622, 0.0144],
|
||||
[ 0.0414, 0.0479, 0.0529],
|
||||
[ 0.0355, 0.0612, -0.0247],
|
||||
[ 0.0147, 0.0264, 0.0174],
|
||||
[ 0.0438, 0.0038, 0.0542],
|
||||
[ 0.0431, -0.0573, -0.0033],
|
||||
[-0.0162, -0.0211, -0.0406],
|
||||
[-0.0487, -0.0295, -0.0393],
|
||||
[ 0.0005, -0.0109, 0.0253],
|
||||
[ 0.0296, 0.0591, 0.0353],
|
||||
[ 0.0119, 0.0181, -0.0306],
|
||||
[-0.0085, -0.0362, 0.0229],
|
||||
[ 0.0005, -0.0106, 0.0242]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 0.9990943042622529
|
||||
|
||||
class Hunyuan3Dv2_1(LatentFormat):
|
||||
scale_factor = 1.0039506158752403
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
@ -470,3 +741,20 @@ class Hunyuan3Dv2mini(LatentFormat):
|
||||
class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 1.0, 0.0, 0.0 ],
|
||||
[ 0.0, 1.0, 0.0 ],
|
||||
[ 0.0, 0.0, 1.0 ]
|
||||
]
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
@ -133,6 +133,7 @@ class Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
@ -140,6 +141,7 @@ class Attention(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
):
|
||||
|
||||
N = hidden_states.shape[0]
|
||||
@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
attn_output, _ = self.attn(
|
||||
@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=None,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=None,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if self.use_adaln_single:
|
||||
@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
@ -313,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
transformer_options={},
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
@ -338,12 +340,34 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
lyrics_strength=1.0,
|
||||
**kwargs
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
|
||||
controlnet_scale, lyrics_strength, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
@ -371,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
@ -380,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
|
||||
else:
|
||||
self.source_sample_rate = source_sample_rate
|
||||
|
||||
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Normalize(0.5, 0.5),
|
||||
])
|
||||
@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
|
||||
self.scale_factor = 0.1786
|
||||
self.shift_factor = -1.9091
|
||||
|
||||
def load_audio(self, audio_path):
|
||||
audio, sr = torchaudio.load(audio_path)
|
||||
return audio, sr
|
||||
|
||||
def forward_mel(self, audios):
|
||||
mels = []
|
||||
for i in range(len(audios)):
|
||||
@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
|
||||
latent = self.dcae.encoder(mel.unsqueeze(0))
|
||||
latents.append(latent)
|
||||
latents = torch.cat(latents, dim=0)
|
||||
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
|
||||
latents = (latents - self.shift_factor) * self.scale_factor
|
||||
return latents
|
||||
# return latents, latent_lengths
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, latents, audio_lengths=None, sr=None):
|
||||
@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
|
||||
wav = self.vocoder.decode(mels[0]).squeeze(1)
|
||||
|
||||
if sr is not None:
|
||||
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
|
||||
wav = torchaudio.functional.resample(wav, 44100, sr)
|
||||
# wav = resampler(wav)
|
||||
else:
|
||||
sr = 44100
|
||||
pred_wavs.append(wav)
|
||||
@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
|
||||
if audio_lengths is not None:
|
||||
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
|
||||
return torch.stack(pred_wavs)
|
||||
# return sr, pred_wavs
|
||||
|
||||
def forward(self, audios, audio_lengths=None, sr=None):
|
||||
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
|
||||
|
||||
@ -298,7 +298,8 @@ class Attention(nn.Module):
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None,
|
||||
causal = None
|
||||
causal = None,
|
||||
transformer_options={},
|
||||
):
|
||||
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||
|
||||
@ -363,7 +364,7 @@ class Attention(nn.Module):
|
||||
heads_per_kv_head = h // kv_h
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = self.to_out(out)
|
||||
|
||||
if mask is not None:
|
||||
@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
|
||||
global_cond=None,
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None
|
||||
rotary_pos_emb = None,
|
||||
transformer_options={}
|
||||
):
|
||||
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
||||
|
||||
@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
|
||||
residual = x
|
||||
x = self.pre_norm(x)
|
||||
x = x * (1 + scale_self) + shift_self
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
x = x * torch.sigmoid(1 - gate_self)
|
||||
x = x + residual
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
|
||||
x = x + residual
|
||||
|
||||
else:
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
@ -632,7 +635,7 @@ class ContinuousTransformer(nn.Module):
|
||||
# Attention layers
|
||||
|
||||
if self.rotary_pos_emb is not None:
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
|
||||
else:
|
||||
rotary_pos_emb = None
|
||||
|
||||
@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
|
||||
@ -9,6 +9,7 @@ import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
@ -84,7 +85,7 @@ class SingleAttention(nn.Module):
|
||||
)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c):
|
||||
def forward(self, c, transformer_options={}):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
|
||||
@ -94,7 +95,7 @@ class SingleAttention(nn.Module):
|
||||
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
|
||||
q, k = self.q_norm1(q), self.k_norm1(k)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
c = self.w1o(output)
|
||||
return c
|
||||
|
||||
@ -143,7 +144,7 @@ class DoubleAttention(nn.Module):
|
||||
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x):
|
||||
def forward(self, c, x, transformer_options={}):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
bsz, seqlen2, _ = x.shape
|
||||
@ -167,7 +168,7 @@ class DoubleAttention(nn.Module):
|
||||
torch.cat([cv, xv], dim=1),
|
||||
)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
c, x = output.split([seqlen1, seqlen2], dim=1)
|
||||
c = self.w1o(c)
|
||||
@ -206,7 +207,7 @@ class MMDiTBlock(nn.Module):
|
||||
self.is_last = is_last
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x, global_cond, **kwargs):
|
||||
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
|
||||
|
||||
cres, xres = c, x
|
||||
|
||||
@ -224,7 +225,7 @@ class MMDiTBlock(nn.Module):
|
||||
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
|
||||
|
||||
# attention
|
||||
c, x = self.attn(c, x)
|
||||
c, x = self.attn(c, x, transformer_options=transformer_options)
|
||||
|
||||
|
||||
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
|
||||
@ -254,13 +255,13 @@ class DiTBlock(nn.Module):
|
||||
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, cx, global_cond, **kwargs):
|
||||
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
|
||||
cxres = cx
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
|
||||
global_cond
|
||||
).chunk(6, dim=1)
|
||||
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
|
||||
cx = self.attn(cx)
|
||||
cx = self.attn(cx, transformer_options=transformer_options)
|
||||
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
|
||||
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
|
||||
cx = gate_mlp.unsqueeze(1) * mlpout
|
||||
@ -436,6 +437,13 @@ class MMDiT(nn.Module):
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
@ -465,13 +473,14 @@ class MMDiT(nn.Module):
|
||||
out = {}
|
||||
out["txt"], out["img"] = layer(args["txt"],
|
||||
args["img"],
|
||||
args["vec"])
|
||||
args["vec"],
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
c = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
|
||||
if len(self.single_layers) > 0:
|
||||
c_len = c.size(1)
|
||||
@ -480,13 +489,13 @@ class MMDiT(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], args["vec"])
|
||||
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
cx = out["img"]
|
||||
else:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
|
||||
x = cx[:, c_len:]
|
||||
|
||||
|
||||
@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
|
||||
|
||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
def forward(self, q, k, v, transformer_options={}):
|
||||
q = self.to_q(q)
|
||||
k = self.to_k(k)
|
||||
v = self.to_v(v)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
|
||||
|
||||
return self.out_proj(out)
|
||||
|
||||
@ -47,13 +47,13 @@ class Attention2D(nn.Module):
|
||||
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
||||
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, kv, self_attn=False):
|
||||
def forward(self, x, kv, self_attn=False, transformer_options={}):
|
||||
orig_shape = x.shape
|
||||
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
||||
if self_attn:
|
||||
kv = torch.cat([x, kv], dim=1)
|
||||
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||
x = self.attn(x, kv, kv)
|
||||
x = self.attn(x, kv, kv, transformer_options=transformer_options)
|
||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||
return x
|
||||
|
||||
@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
|
||||
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, kv):
|
||||
def forward(self, x, kv, transformer_options={}):
|
||||
kv = self.kv_mapper(kv)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -173,7 +173,7 @@ class StageB(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
def _down_encode(self, x, r_embed, clip, transformer_options={}):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@ -187,7 +187,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -199,7 +199,7 @@ class StageB(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@ -216,7 +216,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -228,7 +228,7 @@ class StageB(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
|
||||
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
|
||||
if pixels is None:
|
||||
pixels = x.new_zeros(x.size(0), 3, 8, 8)
|
||||
|
||||
@ -245,8 +245,8 @@ class StageB(nn.Module):
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
|
||||
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
level_outputs = self._down_encode(x, r_embed, clip)
|
||||
x = self._up_decode(level_outputs, r_embed, clip)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@ -182,7 +182,7 @@ class StageC(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None):
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@ -201,7 +201,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -213,7 +213,7 @@ class StageC(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@ -235,7 +235,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@ -247,7 +247,7 @@ class StageC(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
|
||||
# Process the conditioning embeddings
|
||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||
for c in self.t_conds:
|
||||
@ -262,8 +262,8 @@ class StageC(nn.Module):
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(x)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.math import attention
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
QKNorm,
|
||||
SelfAttention,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
# TODO: remove this in a few months
|
||||
SingleStreamBlock = None
|
||||
DoubleStreamBlock = None
|
||||
|
||||
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@ -48,124 +48,6 @@ class Approximator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||
|
||||
# calculate the txt bloks
|
||||
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x.addcmul_(mod.gate, output)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
@ -5,17 +5,18 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
EmbedND,
|
||||
timestep_embedding,
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
)
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
ChromaModulationOut,
|
||||
)
|
||||
@ -39,7 +40,8 @@ class ChromaParams:
|
||||
out_dim: int
|
||||
hidden_dim: int
|
||||
n_layers: int
|
||||
|
||||
txt_ids_dims: list
|
||||
vec_in_dim: int
|
||||
|
||||
|
||||
|
||||
@ -89,6 +91,7 @@ class Chroma(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -97,7 +100,7 @@ class Chroma(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
@ -150,8 +153,6 @@ class Chroma(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
@ -179,7 +180,10 @@ class Chroma(nn.Module):
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
@ -192,14 +196,16 @@ class Chroma(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": double_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@ -208,7 +214,8 @@ class Chroma(nn.Module):
|
||||
txt=txt,
|
||||
vec=double_mod,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -219,7 +226,10 @@ class Chroma(nn.Module):
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
@ -228,17 +238,19 @@ class Chroma(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": single_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@ -248,19 +260,29 @@ class Chroma(nn.Module):
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
if hasattr(self, "final_layer"):
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
if img.ndim != 3 or context.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
||||
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
@ -268,4 +290,4 @@ class Chroma(nn.Module):
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
|
||||
|
||||
206
comfy/ldm/chroma_radiance/layers.py
Normal file
206
comfy/ldm/chroma_radiance/layers.py
Normal file
@ -0,0 +1,206 @@
|
||||
# Adapted from https://github.com/lodestone-rock/flow
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.flux.layers import RMSNorm
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
An embedder module that combines input features with a 2D positional
|
||||
encoding that mimics the Discrete Cosine Transform (DCT).
|
||||
|
||||
This module takes an input tensor of shape (B, P^2, C), where P is the
|
||||
patch size, and enriches it with positional information before projecting
|
||||
it to a new hidden size.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_size_input: int,
|
||||
max_freqs: int,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
"""
|
||||
Initializes the NerfEmbedder.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input tensor.
|
||||
hidden_size_input (int): The desired dimension of the output embedding.
|
||||
max_freqs (int): The number of frequency components to use for both
|
||||
the x and y dimensions of the positional encoding.
|
||||
The total number of positional features will be max_freqs^2.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
|
||||
# A linear layer to project the concatenated input features and
|
||||
# positional encodings to the final output dimension.
|
||||
self.embedder = nn.Sequential(
|
||||
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Generates and caches 2D DCT-like positional embeddings for a given patch size.
|
||||
|
||||
The LRU cache is a performance optimization that avoids recomputing the
|
||||
same positional grid on every forward pass.
|
||||
|
||||
Args:
|
||||
patch_size (int): The side length of the square input patch.
|
||||
device: The torch device to create the tensors on.
|
||||
dtype: The torch dtype for the tensors.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (1, patch_size^2, max_freqs^2) containing the
|
||||
positional embeddings.
|
||||
"""
|
||||
# Create normalized 1D coordinate grids from 0 to 1.
|
||||
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
|
||||
# Create a 2D meshgrid of coordinates.
|
||||
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
|
||||
|
||||
# Reshape positions to be broadcastable with frequencies.
|
||||
# Shape becomes (patch_size^2, 1, 1).
|
||||
pos_x = pos_x.reshape(-1, 1, 1)
|
||||
pos_y = pos_y.reshape(-1, 1, 1)
|
||||
|
||||
# Create a 1D tensor of frequency values from 0 to max_freqs-1.
|
||||
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
|
||||
|
||||
# Reshape frequencies to be broadcastable for creating 2D basis functions.
|
||||
# freqs_x shape: (1, max_freqs, 1)
|
||||
# freqs_y shape: (1, 1, max_freqs)
|
||||
freqs_x = freqs[None, :, None]
|
||||
freqs_y = freqs[None, None, :]
|
||||
|
||||
# A custom weighting coefficient, not part of standard DCT.
|
||||
# This seems to down-weight the contribution of higher-frequency interactions.
|
||||
coeffs = (1 + freqs_x * freqs_y) ** -1
|
||||
|
||||
# Calculate the 1D cosine basis functions for x and y coordinates.
|
||||
# This is the core of the DCT formulation.
|
||||
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
|
||||
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
|
||||
|
||||
# Combine the 1D basis functions to create 2D basis functions by element-wise
|
||||
# multiplication, and apply the custom coefficients. Broadcasting handles the
|
||||
# combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
|
||||
# The result is flattened into a feature vector for each position.
|
||||
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
|
||||
|
||||
return dct
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the embedder.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor of shape (B, P^2, C).
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape (B, P^2, hidden_size_input).
|
||||
"""
|
||||
# Get the batch size, number of pixels, and number of channels.
|
||||
B, P2, C = inputs.shape
|
||||
|
||||
# Infer the patch side length from the number of pixels (P^2).
|
||||
patch_size = int(P2 ** 0.5)
|
||||
|
||||
input_dtype = inputs.dtype
|
||||
inputs = inputs.to(dtype=self.dtype)
|
||||
|
||||
# Fetch the pre-computed or cached positional embeddings.
|
||||
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
|
||||
|
||||
# Repeat the positional embeddings for each item in the batch.
|
||||
dct = dct.repeat(B, 1, 1)
|
||||
|
||||
# Concatenate the original input features with the positional embeddings
|
||||
# along the feature dimension.
|
||||
inputs = torch.cat((inputs, dct), dim=-1)
|
||||
|
||||
# Project the combined tensor to the target hidden size.
|
||||
return self.embedder(inputs).to(dtype=input_dtype)
|
||||
|
||||
|
||||
class NerfGLUBlock(nn.Module):
|
||||
"""
|
||||
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
|
||||
"""
|
||||
def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
# The total number of parameters for the MLP is increased to accommodate
|
||||
# the gate, value, and output projection matrices.
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_x, hidden_size_x = x.shape
|
||||
mlp_params = self.param_generator(s)
|
||||
|
||||
# Split the generated parameters into three parts for the gate, value, and output projection.
|
||||
fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)
|
||||
|
||||
# Reshape the parameters into matrices for batch matrix multiplication.
|
||||
fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)
|
||||
|
||||
# Normalize the generated weight matrices as in the original implementation.
|
||||
fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
|
||||
fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
|
||||
fc2 = torch.nn.functional.normalize(fc2, dim=-2)
|
||||
|
||||
res_x = x
|
||||
x = self.norm(x)
|
||||
|
||||
# Apply the final output projection.
|
||||
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
|
||||
|
||||
return x + res_x
|
||||
|
||||
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
|
||||
|
||||
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.conv = operations.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
|
||||
335
comfy/ldm/chroma_radiance/model.py
Normal file
335
comfy/ldm/chroma_radiance/model.py
Normal file
@ -0,0 +1,335 @@
|
||||
# Credits:
|
||||
# Original Flux code can be found on: https://github.com/black-forest-labs/flux
|
||||
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
|
||||
|
||||
from comfy.ldm.chroma.model import Chroma, ChromaParams
|
||||
from comfy.ldm.chroma.layers import (
|
||||
Approximator,
|
||||
)
|
||||
from .layers import (
|
||||
NerfEmbedder,
|
||||
NerfGLUBlock,
|
||||
NerfFinalLayer,
|
||||
NerfFinalLayerConv,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaRadianceParams(ChromaParams):
|
||||
patch_size: int
|
||||
nerf_hidden_size: int
|
||||
nerf_mlp_ratio: int
|
||||
nerf_depth: int
|
||||
nerf_max_freqs: int
|
||||
# Setting nerf_tile_size to 0 disables tiling.
|
||||
nerf_tile_size: int
|
||||
# Currently one of linear (legacy) or conv.
|
||||
nerf_final_head_type: str
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
use_x0: bool
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
if operations is None:
|
||||
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
params = ChromaRadianceParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.in_dim = params.in_dim
|
||||
self.out_dim = params.out_dim
|
||||
self.hidden_dim = params.hidden_dim
|
||||
self.n_layers = params.n_layers
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in_patch = operations.Conv2d(
|
||||
params.in_channels,
|
||||
params.hidden_size,
|
||||
kernel_size=params.patch_size,
|
||||
stride=params.patch_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
# set as nn identity for now, will overwrite it later.
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
in_dim=self.in_dim,
|
||||
hidden_dim=self.hidden_dim,
|
||||
out_dim=self.out_dim,
|
||||
n_layers=self.n_layers,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
# pixel channel concat with DCT
|
||||
self.nerf_image_embedder = NerfEmbedder(
|
||||
in_channels=params.in_channels,
|
||||
hidden_size_input=params.nerf_hidden_size,
|
||||
max_freqs=params.nerf_max_freqs,
|
||||
dtype=params.nerf_embedder_dtype or dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.nerf_blocks = nn.ModuleList([
|
||||
NerfGLUBlock(
|
||||
hidden_size_s=params.hidden_size,
|
||||
hidden_size_x=params.nerf_hidden_size,
|
||||
mlp_ratio=params.nerf_mlp_ratio,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
) for _ in range(params.nerf_depth)
|
||||
])
|
||||
|
||||
if params.nerf_final_head_type == "linear":
|
||||
self.nerf_final_layer = NerfFinalLayer(
|
||||
params.nerf_hidden_size,
|
||||
out_channels=params.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
elif params.nerf_final_head_type == "conv":
|
||||
self.nerf_final_layer_conv = NerfFinalLayerConv(
|
||||
params.nerf_hidden_size,
|
||||
out_channels=params.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
else:
|
||||
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
|
||||
raise ValueError(errstr)
|
||||
|
||||
self.skip_mmdit = []
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
return self.nerf_final_layer
|
||||
if self.params.nerf_final_head_type == "conv":
|
||||
return self.nerf_final_layer_conv
|
||||
# Impossible to get here as we raise an error on unexpected types on initialization.
|
||||
raise NotImplementedError
|
||||
|
||||
def img_in(self, img: Tensor) -> Tensor:
|
||||
img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
|
||||
# flatten into a sequence for the transformer.
|
||||
return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
|
||||
|
||||
def forward_nerf(
|
||||
self,
|
||||
img_orig: Tensor,
|
||||
img_out: Tensor,
|
||||
params: ChromaRadianceParams,
|
||||
) -> Tensor:
|
||||
B, C, H, W = img_orig.shape
|
||||
num_patches = img_out.shape[1]
|
||||
patch_size = params.patch_size
|
||||
|
||||
# Store the raw pixel values of each patch for the NeRF head later.
|
||||
# unfold creates patches: [B, C * P * P, NumPatches]
|
||||
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
||||
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
||||
# the tile size.
|
||||
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
else:
|
||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
# Pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
img_dct = block(img_dct, nerf_hidden)
|
||||
|
||||
# Reassemble the patches into the final image.
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
|
||||
# Reshape to combine with batch dimension for fold
|
||||
img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
|
||||
img_dct = nn.functional.fold(
|
||||
img_dct,
|
||||
output_size=(H, W),
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
return self._nerf_final_layer(img_dct)
|
||||
|
||||
def forward_tiled_nerf(
|
||||
self,
|
||||
nerf_hidden: Tensor,
|
||||
nerf_pixels: Tensor,
|
||||
batch: int,
|
||||
channels: int,
|
||||
num_patches: int,
|
||||
patch_size: int,
|
||||
params: ChromaRadianceParams,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Processes the NeRF head in tiles to save memory.
|
||||
nerf_hidden has shape [B, L, D]
|
||||
nerf_pixels has shape [B, L, C * P * P]
|
||||
"""
|
||||
tile_size = params.nerf_tile_size
|
||||
output_tiles = []
|
||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||
for i in range(0, num_patches, tile_size):
|
||||
end = min(i + tile_size, num_patches)
|
||||
|
||||
# Slice the current tile from the input tensors
|
||||
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
|
||||
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
|
||||
# pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
|
||||
|
||||
output_tiles.append(img_dct_tile)
|
||||
|
||||
# Concatenate the processed tiles along the patch dimension
|
||||
return torch.cat(output_tiles, dim=0)
|
||||
|
||||
def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
|
||||
params = self.params
|
||||
if not overrides:
|
||||
return params
|
||||
params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
|
||||
nullable_keys = frozenset(("nerf_embedder_dtype",))
|
||||
bad_keys = tuple(k for k in overrides if k not in params_dict)
|
||||
if bad_keys:
|
||||
e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
raise ValueError(e)
|
||||
bad_keys = tuple(
|
||||
k
|
||||
for k, v in overrides.items()
|
||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
||||
)
|
||||
if bad_keys:
|
||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
raise ValueError(e)
|
||||
# At this point it's all valid keys and values so we can merge with the existing params.
|
||||
params_dict |= overrides
|
||||
return params.__class__(**params_dict)
|
||||
|
||||
def _apply_x0_residual(self, predicted, noisy, timesteps):
|
||||
|
||||
# non zero during training to prevent 0 div
|
||||
eps = 0.0
|
||||
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
timestep: Tensor,
|
||||
context: Tensor,
|
||||
guidance: Optional[Tensor],
|
||||
control: Optional[dict]=None,
|
||||
transformer_options: dict={},
|
||||
**kwargs: dict,
|
||||
) -> Tensor:
|
||||
bs, c, h, w = x.shape
|
||||
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
if img.ndim != 4:
|
||||
raise ValueError("Input img tensor must be in [B, C, H, W] format.")
|
||||
if context.ndim != 3:
|
||||
raise ValueError("Input txt tensors must have 3 dimensions.")
|
||||
|
||||
params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
|
||||
|
||||
h_len = (img.shape[-2] // self.patch_size)
|
||||
w_len = (img.shape[-1] // self.patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
|
||||
img_out = self.forward_orig(
|
||||
img,
|
||||
img_ids,
|
||||
context,
|
||||
txt_ids,
|
||||
timestep,
|
||||
guidance,
|
||||
control,
|
||||
transformer_options,
|
||||
attn_mask=kwargs.get("attention_mask", None),
|
||||
)
|
||||
|
||||
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
||||
# If x0 variant → v-pred, just return this instead
|
||||
if hasattr(self, "__x0__"):
|
||||
out = self._apply_x0_residual(out, img, timestep)
|
||||
return out
|
||||
|
||||
@ -26,16 +26,6 @@ from torch import nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||
if name == "I":
|
||||
return nn.Identity()
|
||||
@ -186,6 +176,7 @@ class Attention(nn.Module):
|
||||
context=None,
|
||||
mask=None,
|
||||
rope_emb=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -194,7 +185,7 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
out = rearrange(out, " b n s c -> s b (n c)")
|
||||
return self.to_out(out)
|
||||
@ -556,6 +547,7 @@ class VideoAttn(nn.Module):
|
||||
context: Optional[torch.Tensor] = None,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for video attention.
|
||||
@ -581,6 +573,7 @@ class VideoAttn(nn.Module):
|
||||
context_M_B_D,
|
||||
crossattn_mask,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||
return x_T_H_W_B_D
|
||||
@ -675,6 +668,7 @@ class DITBuildingBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for dynamically configured blocks with adaptive normalization.
|
||||
@ -712,6 +706,7 @@ class DITBuildingBlock(nn.Module):
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
context=None,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
elif self.block_type in ["cross_attn", "ca"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
@ -719,6 +714,7 @@ class DITBuildingBlock(nn.Module):
|
||||
context=crossattn_emb,
|
||||
crossattn_mask=crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type: {self.block_type}")
|
||||
@ -794,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
for block in self.blocks:
|
||||
x = block(
|
||||
@ -803,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
return x
|
||||
|
||||
@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
return x * torch.sigmoid(x)
|
||||
# x * sigmoid(x)
|
||||
return torch.nn.functional.silu(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
|
||||
@ -27,6 +27,8 @@ from torchvision import transforms
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask,
|
||||
fps,
|
||||
image_size,
|
||||
padding_mask,
|
||||
scalar_feature,
|
||||
data_type,
|
||||
latent_condition,
|
||||
latent_condition_sigma,
|
||||
condition_video_augment_sigma,
|
||||
**kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
# crossattn_emb: torch.Tensor,
|
||||
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -482,6 +520,7 @@ class GeneralDIT(nn.Module):
|
||||
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
for _, block in self.blocks.items():
|
||||
assert (
|
||||
self.blocks["block0"].x_format == block.x_format
|
||||
@ -496,6 +535,7 @@ class GeneralDIT(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||
|
||||
@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
h_extrapolation_ratio: float = 1.0,
|
||||
w_extrapolation_ratio: float = 1.0,
|
||||
t_extrapolation_ratio: float = 1.0,
|
||||
enable_fps_modulation: bool = True,
|
||||
device=None,
|
||||
**kwargs, # used for compatibility with other positional embeddings; unused in this class
|
||||
):
|
||||
del kwargs
|
||||
super().__init__()
|
||||
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
|
||||
self.base_fps = base_fps
|
||||
self.max_h = len_h
|
||||
self.max_w = len_w
|
||||
self.enable_fps_modulation = enable_fps_modulation
|
||||
|
||||
dim = head_dim
|
||||
dim_h = dim // 6 * 2
|
||||
@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
|
||||
|
||||
B, T, H, W, _ = B_T_H_W_C
|
||||
seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
|
||||
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
|
||||
assert (
|
||||
uniform_fps or B == 1 or T == 1
|
||||
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
|
||||
assert (
|
||||
H <= self.max_h and W <= self.max_w
|
||||
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
|
||||
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
|
||||
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
|
||||
half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
|
||||
half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)
|
||||
|
||||
# apply sequence scaling in temporal dimension
|
||||
if fps is None: # image case
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
|
||||
if fps is None or self.enable_fps_modulation is False: # image case
|
||||
half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
|
||||
else:
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||
half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||
|
||||
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
|
||||
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
|
||||
|
||||
886
comfy/ldm/cosmos/predict2.py
Normal file
886
comfy/ldm/cosmos/predict2.py
Normal file
@ -0,0 +1,886 @@
|
||||
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
import logging
|
||||
from typing import Callable, Optional, Tuple
|
||||
import math
|
||||
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
# ---------------------- Feed Forward Network -----------------------
|
||||
class GPT2FeedForward(nn.Module):
|
||||
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.activation = nn.GELU()
|
||||
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
|
||||
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self._layer_id = None
|
||||
self._dim = d_model
|
||||
self._hidden_dim = d_ff
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.layer1(x)
|
||||
|
||||
x = self.activation(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
"""Computes multi-head attention using PyTorch's native implementation.
|
||||
|
||||
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
|
||||
It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
|
||||
attention, and rearranges the output back to the original format.
|
||||
|
||||
The input tensor names use the following dimension conventions:
|
||||
|
||||
- B: batch size
|
||||
- S: sequence length
|
||||
- H: number of attention heads
|
||||
- D: head dimension
|
||||
|
||||
Args:
|
||||
q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
|
||||
Returns:
|
||||
Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
|
||||
"""
|
||||
in_q_shape = q_B_S_H_D.shape
|
||||
in_k_shape = k_B_S_H_D.shape
|
||||
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
||||
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
A flexible attention module supporting both self-attention and cross-attention mechanisms.
|
||||
|
||||
This module implements a multi-head attention layer that can operate in either self-attention
|
||||
or cross-attention mode. The mode is determined by whether a context dimension is provided.
|
||||
The implementation uses scaled dot-product attention and supports optional bias terms and
|
||||
dropout regularization.
|
||||
|
||||
Args:
|
||||
query_dim (int): The dimensionality of the query vectors.
|
||||
context_dim (int, optional): The dimensionality of the context (key/value) vectors.
|
||||
If None, the module operates in self-attention mode using query_dim. Default: None
|
||||
n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
|
||||
head_dim (int, optional): The dimension of each attention head. Default: 64
|
||||
dropout (float, optional): Dropout probability applied to the output. Default: 0.0
|
||||
qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd"
|
||||
backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine"
|
||||
|
||||
Examples:
|
||||
>>> # Self-attention with 512 dimensions and 8 heads
|
||||
>>> self_attn = Attention(query_dim=512)
|
||||
>>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
|
||||
>>> out = self_attn(x) # (32, 16, 512)
|
||||
|
||||
>>> # Cross-attention
|
||||
>>> cross_attn = Attention(query_dim=512, context_dim=256)
|
||||
>>> query = torch.randn(32, 16, 512)
|
||||
>>> context = torch.randn(32, 8, 256)
|
||||
>>> out = cross_attn(query, context) # (32, 16, 512)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
context_dim: Optional[int] = None,
|
||||
n_heads: int = 8,
|
||||
head_dim: int = 64,
|
||||
dropout: float = 0.0,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{n_heads} heads with a dimension of {head_dim}."
|
||||
)
|
||||
self.is_selfattn = context_dim is None # self attention
|
||||
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
inner_dim = head_dim * n_heads
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.query_dim = query_dim
|
||||
self.context_dim = context_dim
|
||||
|
||||
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.v_norm = nn.Identity()
|
||||
|
||||
self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
|
||||
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
|
||||
|
||||
self.attn_op = torch_attention_op
|
||||
|
||||
self._query_dim = query_dim
|
||||
self._context_dim = context_dim
|
||||
self._inner_dim = inner_dim
|
||||
|
||||
def compute_qkv(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = self.q_proj(x)
|
||||
context = x if context is None else context
|
||||
k = self.k_proj(context)
|
||||
v = self.v_proj(context)
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
def apply_norm_and_rotary_pos_emb(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
v = self.v_norm(v)
|
||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||
q = apply_rotary_pos_emb(q, rope_emb)
|
||||
k = apply_rotary_pos_emb(k, rope_emb)
|
||||
return q, k, v
|
||||
|
||||
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): The query tensor of shape [B, Mq, K]
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
return self.compute_attention(q, k, v, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
|
||||
assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
|
||||
timesteps = timesteps_B_T.flatten().float()
|
||||
half_dim = self.num_channels // 2
|
||||
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - 0.0)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
sin_emb = torch.sin(emb)
|
||||
cos_emb = torch.cos(emb)
|
||||
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||
|
||||
return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
|
||||
)
|
||||
self.in_dim = in_features
|
||||
self.out_dim = out_features
|
||||
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
|
||||
self.activation = nn.SiLU()
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if use_adaln_lora:
|
||||
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
|
||||
else:
|
||||
self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
emb = self.linear_1(sample)
|
||||
emb = self.activation(emb)
|
||||
emb = self.linear_2(emb)
|
||||
|
||||
if self.use_adaln_lora:
|
||||
adaln_lora_B_T_3D = emb
|
||||
emb_B_T_D = sample
|
||||
else:
|
||||
adaln_lora_B_T_3D = None
|
||||
emb_B_T_D = emb
|
||||
|
||||
return emb_B_T_D, adaln_lora_B_T_3D
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
|
||||
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
|
||||
making it suitable for video and image processing tasks. It supports dividing the input into patches
|
||||
and embedding each patch into a vector of size `out_channels`.
|
||||
|
||||
Parameters:
|
||||
- spatial_patch_size (int): The size of each spatial patch.
|
||||
- temporal_patch_size (int): The size of each temporal patch.
|
||||
- in_channels (int): Number of input channels. Default: 3.
|
||||
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
|
||||
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spatial_patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 768,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.spatial_patch_size = spatial_patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
Rearrange(
|
||||
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
|
||||
r=temporal_patch_size,
|
||||
m=spatial_patch_size,
|
||||
n=spatial_patch_size,
|
||||
),
|
||||
operations.Linear(
|
||||
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
|
||||
),
|
||||
)
|
||||
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the PatchEmbed module.
|
||||
|
||||
Parameters:
|
||||
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
|
||||
B is the batch size,
|
||||
C is the number of channels,
|
||||
T is the temporal dimension,
|
||||
H is the height, and
|
||||
W is the width of the input.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
|
||||
"""
|
||||
assert x.dim() == 5
|
||||
_, _, T, H, W = x.shape
|
||||
assert (
|
||||
H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
|
||||
), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
|
||||
assert T % self.temporal_patch_size == 0
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of video DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
spatial_patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
out_channels: int,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.n_adaln_chunks = 2
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
if use_adaln_lora:
|
||||
self.adaln_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
else:
|
||||
self.adaln_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
shift_B_T_D, scale_B_T_D = (
|
||||
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
|
||||
).chunk(2, dim=-1)
|
||||
else:
|
||||
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
|
||||
|
||||
shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
|
||||
scale_B_T_D, "b t d -> b t 1 1 d"
|
||||
)
|
||||
|
||||
def _fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
_norm_layer: nn.Module,
|
||||
_scale_B_T_1_1_D: torch.Tensor,
|
||||
_shift_B_T_1_1_D: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
|
||||
|
||||
x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
|
||||
x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
|
||||
return x_B_T_H_W_O
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""
|
||||
A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
|
||||
Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
|
||||
|
||||
Parameters:
|
||||
x_dim (int): Dimension of input features
|
||||
context_dim (int): Dimension of context features for cross-attention
|
||||
num_heads (int): Number of attention heads
|
||||
mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
|
||||
adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
|
||||
|
||||
The block applies the following sequence:
|
||||
1. Self-attention with AdaLN modulation
|
||||
2. Cross-attention with AdaLN modulation
|
||||
3. MLP with AdaLN modulation
|
||||
|
||||
Each component uses skip connections and layer normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_dim: int,
|
||||
context_dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.x_dim = x_dim
|
||||
self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.cross_attn = Attention(
|
||||
x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if self.use_adaln_lora:
|
||||
self.adaln_modulation_self_attn = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
self.adaln_modulation_cross_attn = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
self.adaln_modulation_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
else:
|
||||
self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
if self.use_adaln_lora:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
|
||||
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
|
||||
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
|
||||
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
else:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
|
||||
|
||||
# Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
|
||||
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
B, T, H, W, D = x_B_T_H_W_D.shape
|
||||
|
||||
def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
|
||||
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_self_attn,
|
||||
scale_self_attn_B_T_1_1_D,
|
||||
shift_self_attn_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = rearrange(
|
||||
self.self_attn(
|
||||
# normalized_x_B_T_HW_D,
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
layer_norm_cross_attn: Callable,
|
||||
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
_normalized_x_B_T_H_W_D = _fn(
|
||||
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
|
||||
)
|
||||
_result_B_T_H_W_D = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
return _result_B_T_H_W_D
|
||||
|
||||
result_B_T_H_W_D = _x_fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_cross_attn,
|
||||
scale_cross_attn_B_T_1_1_D,
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_mlp,
|
||||
scale_mlp_B_T_1_1_D,
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
class MiniTrainDIT(nn.Module):
|
||||
"""
|
||||
A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
|
||||
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||
|
||||
Args:
|
||||
max_img_h (int): Maximum height of the input images.
|
||||
max_img_w (int): Maximum width of the input images.
|
||||
max_frames (int): Maximum number of frames in the video sequence.
|
||||
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
||||
out_channels (int): Number of output channels.
|
||||
patch_spatial (tuple): Spatial resolution of patches for input processing.
|
||||
patch_temporal (int): Temporal resolution of patches for input processing.
|
||||
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
||||
model_channels (int): Base number of channels used throughout the model.
|
||||
num_blocks (int): Number of transformer blocks.
|
||||
num_heads (int): Number of heads in the multi-head attention layers.
|
||||
mlp_ratio (float): Expansion ratio for MLP blocks.
|
||||
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
||||
pos_emb_cls (str): Type of positional embeddings.
|
||||
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
||||
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
||||
min_fps (int): Minimum frames per second.
|
||||
max_fps (int): Maximum frames per second.
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
||||
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
||||
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
||||
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
||||
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
||||
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
||||
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
||||
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_img_h: int,
|
||||
max_img_w: int,
|
||||
max_frames: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
patch_spatial: int, # tuple,
|
||||
patch_temporal: int,
|
||||
concat_padding_mask: bool = True,
|
||||
# attention settings
|
||||
model_channels: int = 768,
|
||||
num_blocks: int = 10,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
# cross attention settings
|
||||
crossattn_emb_channels: int = 1024,
|
||||
# positional embedding settings
|
||||
pos_emb_cls: str = "sincos",
|
||||
pos_emb_learnable: bool = False,
|
||||
pos_emb_interpolation: str = "crop",
|
||||
min_fps: int = 1,
|
||||
max_fps: int = 30,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
rope_h_extrapolation_ratio: float = 1.0,
|
||||
rope_w_extrapolation_ratio: float = 1.0,
|
||||
rope_t_extrapolation_ratio: float = 1.0,
|
||||
extra_per_block_abs_pos_emb: bool = False,
|
||||
extra_h_extrapolation_ratio: float = 1.0,
|
||||
extra_w_extrapolation_ratio: float = 1.0,
|
||||
extra_t_extrapolation_ratio: float = 1.0,
|
||||
rope_enable_fps_modulation: bool = True,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.max_img_h = max_img_h
|
||||
self.max_img_w = max_img_w
|
||||
self.max_frames = max_frames
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_spatial = patch_spatial
|
||||
self.patch_temporal = patch_temporal
|
||||
self.num_heads = num_heads
|
||||
self.num_blocks = num_blocks
|
||||
self.model_channels = model_channels
|
||||
self.concat_padding_mask = concat_padding_mask
|
||||
# positional embedding settings
|
||||
self.pos_emb_cls = pos_emb_cls
|
||||
self.pos_emb_learnable = pos_emb_learnable
|
||||
self.pos_emb_interpolation = pos_emb_interpolation
|
||||
self.min_fps = min_fps
|
||||
self.max_fps = max_fps
|
||||
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
|
||||
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
|
||||
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
|
||||
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
|
||||
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
|
||||
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
|
||||
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
|
||||
self.rope_enable_fps_modulation = rope_enable_fps_modulation
|
||||
|
||||
self.build_pos_embed(device=device, dtype=dtype)
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
self.t_embedder = nn.Sequential(
|
||||
Timesteps(model_channels),
|
||||
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
|
||||
)
|
||||
|
||||
in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||
self.x_embedder = PatchEmbed(
|
||||
spatial_patch_size=patch_spatial,
|
||||
temporal_patch_size=patch_temporal,
|
||||
in_channels=in_channels,
|
||||
out_channels=model_channels,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
x_dim=model_channels,
|
||||
context_dim=crossattn_emb_channels,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
use_adaln_lora=use_adaln_lora,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size=self.model_channels,
|
||||
spatial_patch_size=self.patch_spatial,
|
||||
temporal_patch_size=self.patch_temporal,
|
||||
out_channels=self.out_channels,
|
||||
use_adaln_lora=self.use_adaln_lora,
|
||||
adaln_lora_dim=self.adaln_lora_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
def build_pos_embed(self, device=None, dtype=None) -> None:
|
||||
if self.pos_emb_cls == "rope3d":
|
||||
cls_type = VideoRopePosition3DEmb
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
|
||||
|
||||
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
kwargs = dict(
|
||||
model_channels=self.model_channels,
|
||||
len_h=self.max_img_h // self.patch_spatial,
|
||||
len_w=self.max_img_w // self.patch_spatial,
|
||||
len_t=self.max_frames // self.patch_temporal,
|
||||
max_fps=self.max_fps,
|
||||
min_fps=self.min_fps,
|
||||
is_learnable=self.pos_emb_learnable,
|
||||
interpolation=self.pos_emb_interpolation,
|
||||
head_dim=self.model_channels // self.num_heads,
|
||||
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
|
||||
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
|
||||
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
|
||||
enable_fps_modulation=self.rope_enable_fps_modulation,
|
||||
device=device,
|
||||
)
|
||||
self.pos_embedder = cls_type(
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
|
||||
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
|
||||
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
|
||||
kwargs["device"] = device
|
||||
kwargs["dtype"] = dtype
|
||||
self.extra_pos_embedder = LearnablePosEmbAxis(
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
def prepare_embedded_sequence(
|
||||
self,
|
||||
x_B_C_T_H_W: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
|
||||
|
||||
Args:
|
||||
x_B_C_T_H_W (torch.Tensor): video
|
||||
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
|
||||
If None, a default value (`self.base_fps`) will be used.
|
||||
padding_mask (Optional[torch.Tensor]): current it is not used
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
|
||||
- An optional positional embedding tensor, returned only if the positional embedding class
|
||||
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
|
||||
|
||||
Notes:
|
||||
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
|
||||
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
|
||||
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
|
||||
the `self.pos_embedder` with the shape [T, H, W].
|
||||
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
|
||||
`self.pos_embedder` with the fps tensor.
|
||||
- Otherwise, the positional embeddings are generated without considering fps.
|
||||
"""
|
||||
if self.concat_padding_mask:
|
||||
if padding_mask is None:
|
||||
padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
|
||||
else:
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
x_B_C_T_H_W = torch.cat(
|
||||
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
|
||||
)
|
||||
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
|
||||
else:
|
||||
extra_pos_emb = None
|
||||
|
||||
if "rope" in self.pos_emb_cls.lower():
|
||||
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||
|
||||
return x_B_T_H_W_D, None, extra_pos_emb
|
||||
|
||||
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
|
||||
x_B_C_Tt_Hp_Wp = rearrange(
|
||||
x_B_T_H_W_M,
|
||||
"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
|
||||
p1=self.patch_spatial,
|
||||
p2=self.patch_spatial,
|
||||
t=self.patch_temporal,
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
x_B_C_T_H_W = x
|
||||
timesteps_B_T = timesteps
|
||||
crossattn_emb = context
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||
timesteps: (B, ) tensor of timesteps
|
||||
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||
"""
|
||||
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
|
||||
x_B_C_T_H_W,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
|
||||
if timesteps_B_T.ndim == 1:
|
||||
timesteps_B_T = timesteps_B_T.unsqueeze(1)
|
||||
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
|
||||
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
|
||||
|
||||
# for logging purpose
|
||||
affline_scale_log_info = {}
|
||||
affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
|
||||
self.affline_scale_log_info = affline_scale_log_info
|
||||
self.affline_emb = t_embedding_B_T_D
|
||||
self.crossattn_emb = crossattn_emb
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
assert (
|
||||
x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
|
||||
|
||||
block_kwargs = {
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
|
||||
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
t_embedding_B_T_D,
|
||||
crossattn_emb,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
@ -121,6 +121,11 @@ class ControlNetFlux(Flux):
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
else:
|
||||
y = y[:, :self.params.vec_in_dim]
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
|
||||
@ -174,7 +179,7 @@ class ControlNetFlux(Flux):
|
||||
out["output"] = out_output[:self.main_model_single]
|
||||
return out
|
||||
|
||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||
def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
|
||||
patch_size = 2
|
||||
if self.latent_input:
|
||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||
|
||||
@ -48,15 +48,44 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
return embedding
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
class YakMLP(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
if yak_mlp:
|
||||
return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
|
||||
if mlp_silu_act:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
@ -80,14 +109,14 @@ class QKNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -98,11 +127,11 @@ class ModulationOut:
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
if vec.ndim == 2:
|
||||
@ -118,7 +147,7 @@ class Modulation(nn.Module):
|
||||
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
if modulation_dims is None:
|
||||
if m_add is not None:
|
||||
return tensor * m_mult + m_add
|
||||
return torch.addcmul(m_add, tensor, m_mult)
|
||||
else:
|
||||
return tensor * m_mult
|
||||
else:
|
||||
@ -129,77 +158,107 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
return tensor
|
||||
|
||||
|
||||
class SiLUActivation(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return self.gate_fn(x1) * x2
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.modulation = modulation
|
||||
|
||||
if self.modulation:
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if self.modulation:
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
else:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
del img_modulated
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del img_qkv
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
del txt_modulated
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
del img_attn
|
||||
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
del txt_attn
|
||||
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
@ -220,6 +279,10 @@ class SingleStreamBlock(nn.Module):
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
modulation=True,
|
||||
mlp_silu_act=False,
|
||||
bias=True,
|
||||
yak_mlp=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
@ -231,30 +294,55 @@ class SingleStreamBlock(nn.Module):
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.mlp_hidden_dim_first = self.mlp_hidden_dim
|
||||
self.yak_mlp = yak_mlp
|
||||
if mlp_silu_act:
|
||||
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
|
||||
self.mlp_act = SiLUActivation()
|
||||
else:
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
if self.yak_mlp:
|
||||
self.mlp_hidden_dim_first *= 2
|
||||
self.mlp_act = nn.SiLU()
|
||||
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
if modulation:
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.modulation = None
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
|
||||
if self.modulation:
|
||||
mod, _ = self.modulation(vec)
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
else:
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
@ -262,11 +350,11 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
|
||||
if vec.ndim == 2:
|
||||
|
||||
@ -6,18 +6,11 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
if pe is not None:
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
q, k = apply_rope(q, k, pe)
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
@ -35,11 +28,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
@ -6,6 +6,7 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@ -14,6 +15,8 @@ from .layers import (
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@ -32,6 +35,14 @@ class FluxParams:
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
txt_ids_dims: list
|
||||
global_modulation: bool = False
|
||||
mlp_silu_act: bool = False
|
||||
ops_bias: bool = True
|
||||
default_ref_method: str = "offset"
|
||||
ref_index_scale: float = 1.0
|
||||
yak_mlp: bool = False
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
@ -57,13 +68,22 @@ class Flux(nn.Module):
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
@ -72,6 +92,10 @@ class Flux(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=params.global_modulation is False,
|
||||
mlp_silu_act=params.mlp_silu_act,
|
||||
proj_bias=params.ops_bias,
|
||||
yak_mlp=params.yak_mlp,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -80,13 +104,30 @@ class Flux(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if params.global_modulation:
|
||||
self.double_stream_modulation_img = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.double_stream_modulation_txt = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.single_stream_modulation = Modulation(
|
||||
self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@ -101,6 +142,8 @@ class Flux(nn.Module):
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@ -112,9 +155,27 @@ class Flux(nn.Module):
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||
if self.vector_in is not None:
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
|
||||
if self.txt_norm is not None:
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
img = out["img"]
|
||||
txt = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
txt_ids = out["txt_ids"]
|
||||
|
||||
if img_ids is not None:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
@ -122,7 +183,10 @@ class Flux(nn.Module):
|
||||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -130,14 +194,16 @@ class Flux(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@ -146,62 +212,142 @@ class Flux(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
img[:, :add.shape[1]] += add
|
||||
|
||||
if img.dtype == torch.float16:
|
||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
steps_h = h_len
|
||||
steps_w = w_len
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
index += rope_options.get("shift_t", 0.0)
|
||||
h_offset += rope_options.get("shift_y", 0.0)
|
||||
w_offset += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h_orig, w_orig = x.shape
|
||||
patch_size = self.patch_size
|
||||
|
||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img_tokens = img.shape[1]
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method == "index":
|
||||
index += self.params.ref_index_scale
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif ref_latents_method == "uxo":
|
||||
index = 0
|
||||
h_offset = h_len * patch_size + h
|
||||
w_offset = w_len * patch_size + w
|
||||
h += ref.shape[-2]
|
||||
w += ref.shape[-1]
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
else:
|
||||
h_offset = h
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
if len(self.params.txt_ids_dims) > 0:
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
|
||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||
crop_y,
|
||||
transformer_options={},
|
||||
**rope_rotation,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rope_cos = rope_rotation.get("rope_cos")
|
||||
@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
|
||||
|
||||
xy = optimized_attention(q,
|
||||
k,
|
||||
v, self.num_heads, skip_reshape=True)
|
||||
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||
x = self.proj_x(x)
|
||||
@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
transformer_options={},
|
||||
**attn_kwargs,
|
||||
):
|
||||
"""Forward pass of a block.
|
||||
@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
y,
|
||||
scale_x=scale_msa_x,
|
||||
scale_y=scale_msa_y,
|
||||
transformer_options=transformer_options,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
|
||||
args["txt"],
|
||||
rope_cos=args["rope_cos"],
|
||||
rope_sin=args["rope_sin"],
|
||||
crop_y=args["num_tokens"]
|
||||
crop_y=args["num_tokens"],
|
||||
transformer_options=args["transformer_options"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
y_feat = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
transformer_options=transformer_options,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
@ -71,8 +72,8 @@ class TimestepEmbed(nn.Module):
|
||||
return t_emb
|
||||
|
||||
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
|
||||
|
||||
|
||||
class HiDreamAttnProcessor_flashattn:
|
||||
@ -85,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
@ -132,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
|
||||
query = torch.cat([query_1, query_2], dim=-1)
|
||||
key = torch.cat([key_1, key_2], dim=-1)
|
||||
|
||||
hidden_states = attention(query, key, value)
|
||||
hidden_states = attention(query, key, value, transformer_options=transformer_options)
|
||||
|
||||
if not attn.single:
|
||||
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
|
||||
@ -198,6 +200,7 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks: torch.FloatTensor = None,
|
||||
norm_text_tokens: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
@ -205,6 +208,7 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@ -405,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
|
||||
@ -418,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
|
||||
@ -482,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
|
||||
@ -499,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
image_tokens_masks,
|
||||
norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
@ -549,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
return self.block(
|
||||
image_tokens,
|
||||
@ -556,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens,
|
||||
adaln_input,
|
||||
rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@ -692,7 +701,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
image_cond=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
@ -769,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens = cur_encoder_hidden_states,
|
||||
adaln_input = adaln_input,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
|
||||
block_id += 1
|
||||
@ -792,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens=None,
|
||||
adaln_input=adaln_input,
|
||||
rope=rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = hidden_states[:, :hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
@ -91,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@ -107,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
@ -118,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
img = self.final_layer(img, vec)
|
||||
|
||||
@ -4,81 +4,458 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from typing import Union, Tuple, List, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
# manually create the pointer vector
|
||||
assert src.size(0) == batch.numel()
|
||||
|
||||
return xyz, grid_size, length
|
||||
batch_size = int(batch.max()) + 1
|
||||
deg = src.new_zeros(batch_size, dtype = torch.long)
|
||||
|
||||
deg.scatter_add_(0, batch, torch.ones_like(batch))
|
||||
|
||||
ptr_vec = deg.new_zeros(batch_size + 1)
|
||||
torch.cumsum(deg, 0, out=ptr_vec[1:])
|
||||
|
||||
#return fps_sampling(src, ptr_vec, ratio)
|
||||
sampled_indicies = []
|
||||
|
||||
for b in range(batch_size):
|
||||
# start and the end of each batch
|
||||
start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
|
||||
# points from the point cloud
|
||||
points = src[start:end]
|
||||
|
||||
num_points = points.size(0)
|
||||
num_samples = max(1, math.ceil(num_points * sampling_ratio))
|
||||
|
||||
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
|
||||
distances = torch.full((num_points,), float("inf"), device = src.device)
|
||||
|
||||
# select a random start point
|
||||
if start_random:
|
||||
farthest = torch.randint(0, num_points, (1,), device = src.device)
|
||||
else:
|
||||
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
|
||||
|
||||
for i in range(num_samples):
|
||||
selected[i] = farthest
|
||||
centroid = points[farthest].squeeze(0)
|
||||
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
|
||||
distances = torch.minimum(distances, dist)
|
||||
farthest = torch.argmax(distances)
|
||||
|
||||
sampled_indicies.append(torch.arange(start, end)[selected])
|
||||
|
||||
return torch.cat(sampled_indicies, dim = 0)
|
||||
class PointCrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
num_latents: int,
|
||||
downsample_ratio: float,
|
||||
pc_size: int,
|
||||
pc_sharpedge_size: int,
|
||||
point_feats: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
layers: int,
|
||||
fourier_embedder,
|
||||
normal_pe: bool = False,
|
||||
qkv_bias: bool = False,
|
||||
use_ln_post: bool = True,
|
||||
qk_norm: bool = True):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.fourier_embedder = fourier_embedder
|
||||
|
||||
self.pc_size = pc_size
|
||||
self.normal_pe = normal_pe
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.pc_sharpedge_size = pc_sharpedge_size
|
||||
self.num_latents = num_latents
|
||||
self.point_feats = point_feats
|
||||
|
||||
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
|
||||
|
||||
self.cross_attn = ResidualCrossAttentionBlock(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm
|
||||
)
|
||||
|
||||
self.self_attn = None
|
||||
if layers > 0:
|
||||
self.self_attn = Transformer(
|
||||
width = width,
|
||||
heads = heads,
|
||||
qkv_bias = qkv_bias,
|
||||
qk_norm = qk_norm,
|
||||
layers = layers
|
||||
)
|
||||
|
||||
if use_ln_post:
|
||||
self.ln_post = nn.LayerNorm(width)
|
||||
else:
|
||||
self.ln_post = None
|
||||
|
||||
def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
"""
|
||||
Subsample points randomly from the point cloud (input_pc)
|
||||
Further sample the subsampled points to get query_pc
|
||||
take the fourier embeddings for both input and query pc
|
||||
|
||||
Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
|
||||
Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc).
|
||||
More computationally efficient.
|
||||
|
||||
Features are additional information for each point in the cloud
|
||||
"""
|
||||
|
||||
B, _, D = point_cloud.shape
|
||||
|
||||
num_latents = int(self.num_latents)
|
||||
|
||||
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
|
||||
num_sharpedge_query = num_latents - num_random_query
|
||||
|
||||
# Split random and sharpedge surface points
|
||||
random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
|
||||
|
||||
# assert statements
|
||||
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
|
||||
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
|
||||
|
||||
input_random_pc_size = int(num_random_query * self.downsample_ratio)
|
||||
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
|
||||
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
|
||||
|
||||
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
|
||||
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
|
||||
|
||||
else:
|
||||
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
|
||||
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
# concat the random and sharpedges
|
||||
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
|
||||
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
|
||||
|
||||
query = self.fourier_embedder(query_pc)
|
||||
data = self.fourier_embedder(input_pc)
|
||||
|
||||
if self.point_feats > 0:
|
||||
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
|
||||
|
||||
input_random_surface_features, query_random_features = \
|
||||
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
|
||||
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
|
||||
|
||||
if input_sharpedge_pc_size == 0:
|
||||
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = input_random_surface_features.dtype, device = point_cloud.device)
|
||||
|
||||
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
|
||||
dtype = query_random_features.dtype, device = point_cloud.device)
|
||||
else:
|
||||
|
||||
input_sharpedge_surface_features, query_sharpedge_features = \
|
||||
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
|
||||
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
|
||||
|
||||
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
|
||||
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
|
||||
|
||||
if self.normal_pe:
|
||||
# apply the fourier embeddings on the first 3 dims (xyz)
|
||||
input_features_pe = self.fourier_embedder(input_features[..., :3])
|
||||
query_features_pe = self.fourier_embedder(query_features[..., :3])
|
||||
# replace the first 3 dims with the new PE ones
|
||||
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
|
||||
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
|
||||
|
||||
# concat at the channels dim
|
||||
query = torch.cat([query, query_features], dim = -1)
|
||||
data = torch.cat([data, input_features], dim = -1)
|
||||
|
||||
# don't return pc_info to avoid unnecessary memory usuage
|
||||
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
|
||||
|
||||
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
|
||||
|
||||
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
|
||||
|
||||
# apply projections
|
||||
query = self.input_proj(query)
|
||||
data = self.input_proj(data)
|
||||
|
||||
# apply cross attention between query and data
|
||||
latents = self.cross_attn(query, data)
|
||||
|
||||
if self.self_attn is not None:
|
||||
latents = self.self_attn(latents)
|
||||
|
||||
if self.ln_post is not None:
|
||||
latents = self.ln_post(latents)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
def subsample(self, pc, num_query, input_pc_size: int):
|
||||
|
||||
"""
|
||||
num_query: number of points to keep after FPS
|
||||
input_pc_size: number of points to select before FPS
|
||||
"""
|
||||
|
||||
B, _, D = pc.shape
|
||||
query_ratio = num_query / input_pc_size
|
||||
|
||||
# random subsampling of points inside the point cloud
|
||||
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
|
||||
input_pc = pc[:, idx_pc, :]
|
||||
|
||||
# flatten to allow applying fps across the whole batch
|
||||
flattent_input_pc = input_pc.view(B * input_pc_size, D)
|
||||
|
||||
# construct a batch_down tensor to tell fps
|
||||
# which points belong to which batch
|
||||
N_down = int(flattent_input_pc.shape[0] / B)
|
||||
batch_down = torch.arange(B).to(pc.device)
|
||||
batch_down = torch.repeat_interleave(batch_down, N_down)
|
||||
|
||||
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
|
||||
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
|
||||
|
||||
return query_pc, input_pc, idx_pc, idx_query
|
||||
|
||||
def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
|
||||
|
||||
B = batch_size
|
||||
|
||||
input_surface_features = features[:, idx_pc, :]
|
||||
flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
|
||||
query_features = flattent_input_features[idx_query].view(B, -1,
|
||||
flattent_input_features.shape[-1])
|
||||
|
||||
return input_surface_features, query_features
|
||||
|
||||
def normalize_mesh(mesh, scale = 0.9999):
|
||||
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
|
||||
|
||||
bbox = mesh.bounds
|
||||
center = (bbox[1] + bbox[0]) / 2
|
||||
|
||||
max_extent = (bbox[1] - bbox[0]).max()
|
||||
mesh.apply_translation(-center)
|
||||
mesh.apply_scale((2 * scale) / max_extent)
|
||||
|
||||
return mesh
|
||||
|
||||
def sample_pointcloud(mesh, num = 200000):
|
||||
""" Uniformly sample points from the surface of the mesh """
|
||||
|
||||
points, face_idx = mesh.sample(num, return_index = True)
|
||||
normals = mesh.face_normals[face_idx]
|
||||
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
|
||||
|
||||
def detect_sharp_edges(mesh, threshold=0.985):
|
||||
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
|
||||
|
||||
V, F = mesh.vertices, mesh.faces
|
||||
VN, FN = mesh.vertex_normals, mesh.face_normals
|
||||
|
||||
sharp_mask = np.ones(V.shape[0])
|
||||
for i in range(3):
|
||||
indices = F[:, i]
|
||||
alignment = np.einsum('ij,ij->i', VN[indices], FN)
|
||||
dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
|
||||
sharp_mask[indices] = np.min(dot_stack, axis=-1)
|
||||
|
||||
edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
|
||||
edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
|
||||
sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
|
||||
|
||||
return edge_a[sharp_edges], edge_b[sharp_edges]
|
||||
|
||||
|
||||
def sharp_sample_pointcloud(mesh, num = 16384):
|
||||
""" Sample points preferentially from sharp edges in the mesh. """
|
||||
|
||||
edge_a, edge_b = detect_sharp_edges(mesh)
|
||||
V, VN = mesh.vertices, mesh.vertex_normals
|
||||
|
||||
va, vb = V[edge_a], V[edge_b]
|
||||
na, nb = VN[edge_a], VN[edge_b]
|
||||
|
||||
edge_lengths = np.linalg.norm(vb - va, axis=-1)
|
||||
weights = edge_lengths / edge_lengths.sum()
|
||||
|
||||
indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
|
||||
t = np.random.rand(num, 1)
|
||||
|
||||
samples = t * va[indices] + (1 - t) * vb[indices]
|
||||
normals = t * na[indices] + (1 - t) * nb[indices]
|
||||
|
||||
return samples.astype(np.float32), normals.astype(np.float32)
|
||||
|
||||
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
|
||||
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
|
||||
|
||||
import trimesh
|
||||
|
||||
try:
|
||||
mesh_full = trimesh.util.concatenate(mesh.dump())
|
||||
except Exception:
|
||||
mesh_full = trimesh.util.concatenate(mesh)
|
||||
|
||||
mesh_full = normalize_mesh(mesh_full)
|
||||
|
||||
faces = mesh_full.faces
|
||||
vertices = mesh_full.vertices
|
||||
origin_face_count = faces.shape[0]
|
||||
|
||||
mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
|
||||
mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
|
||||
|
||||
area_surface = mesh_surface.area
|
||||
area_fill = mesh_fill.area
|
||||
total_area = area_surface + area_fill
|
||||
|
||||
sample_num = 499712 // 2
|
||||
fill_ratio = area_fill / total_area if total_area > 0 else 0
|
||||
|
||||
num_fill = int(sample_num * fill_ratio)
|
||||
num_surface = sample_num - num_fill
|
||||
|
||||
surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
|
||||
fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
|
||||
|
||||
sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
|
||||
|
||||
def assemble_tensor(points, normals, label=None):
|
||||
|
||||
data = torch.cat([points, normals], dim=1).half().to(device)
|
||||
|
||||
if label is not None:
|
||||
label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
|
||||
data = torch.cat([data, label_tensor], dim=1)
|
||||
|
||||
return data
|
||||
|
||||
surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
|
||||
torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
|
||||
label = 0 if sharpedge_flag else None)
|
||||
|
||||
sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
|
||||
label = 1 if sharpedge_flag else None)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
|
||||
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
|
||||
|
||||
full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
|
||||
|
||||
return full
|
||||
|
||||
class SharpEdgeSurfaceLoader:
|
||||
""" Load mesh surface and sharp edge samples. """
|
||||
|
||||
def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
|
||||
|
||||
self.num_uniform_points = num_uniform_points
|
||||
self.num_sharp_points = num_sharp_points
|
||||
self.total_points = num_uniform_points + num_sharp_points
|
||||
|
||||
def __call__(self, mesh_input, device = "cuda"):
|
||||
mesh = self._load_mesh(mesh_input)
|
||||
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
|
||||
|
||||
@staticmethod
|
||||
def _load_mesh(mesh_input):
|
||||
import trimesh
|
||||
|
||||
if isinstance(mesh_input, str):
|
||||
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
|
||||
else:
|
||||
mesh = mesh_input
|
||||
|
||||
if isinstance(mesh, trimesh.Scene):
|
||||
combined = None
|
||||
for obj in mesh.geometry.values():
|
||||
combined = obj if combined is None else combined + obj
|
||||
return combined
|
||||
|
||||
return mesh
|
||||
|
||||
class DiagonalGaussianDistribution:
|
||||
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
|
||||
|
||||
# divide quant channels (8) into mean and log variance
|
||||
self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
|
||||
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
|
||||
def sample(self):
|
||||
|
||||
eps = torch.randn_like(self.std)
|
||||
z = self.mean + eps * self.std
|
||||
|
||||
return z
|
||||
|
||||
################################################
|
||||
# Volume Decoder
|
||||
################################################
|
||||
|
||||
class VanillaVolumeDecoder():
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
|
||||
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
|
||||
|
||||
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
|
||||
|
||||
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
|
||||
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
|
||||
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
|
||||
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
|
||||
chunk_queries = xyz[start: start + num_chunks, :]
|
||||
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
|
||||
logits = geo_decoder(queries = chunk_queries, latents = latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
grid_logits = torch.cat(batch_logits, dim = 1)
|
||||
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
@ -232,38 +607,41 @@ class MLP(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
n_data = None,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=ops.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.n_data = n_data
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
|
||||
return out
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
|
||||
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
|
||||
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = ops.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
|
||||
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = ops.Linear(width * downsample_ratio, width)
|
||||
if self.enable_ln_post == False:
|
||||
if not self.enable_ln_post:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
width=width,
|
||||
@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
class ShapeVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
self,
|
||||
*,
|
||||
num_latents: int = 4096,
|
||||
embed_dim: int = 64,
|
||||
width: int = 1024,
|
||||
heads: int = 16,
|
||||
num_decoder_layers: int = 16,
|
||||
num_encoder_layers: int = 8,
|
||||
pc_size: int = 81920,
|
||||
pc_sharpedge_size: int = 0,
|
||||
point_feats: int = 4,
|
||||
downsample_ratio: int = 20,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
drop_path_rate: float = 0.0,
|
||||
include_pi: bool = False,
|
||||
scale_factor: float = 1.0039506158752403,
|
||||
label_type: str = "binary",
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.encoder = PointCrossAttention(layers = num_encoder_layers,
|
||||
num_latents = num_latents,
|
||||
downsample_ratio = downsample_ratio,
|
||||
heads = heads,
|
||||
pc_size = pc_size,
|
||||
width = width,
|
||||
point_feats = point_feats,
|
||||
fourier_embedder = self.fourier_embedder,
|
||||
pc_sharpedge_size = pc_sharpedge_size)
|
||||
|
||||
self.post_kl = ops.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
|
||||
return grid_logits.movedim(-2, -1)
|
||||
|
||||
def encode(self, x):
|
||||
return None
|
||||
def encode(self, surface):
|
||||
|
||||
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
||||
latents = self.encoder(pc, feats)
|
||||
|
||||
moments = self.pre_kl(latents)
|
||||
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
|
||||
|
||||
latents = posterior.sample()
|
||||
|
||||
return latents
|
||||
|
||||
659
comfy/ldm/hunyuan3dv2_1/hunyuandit.py
Normal file
659
comfy/ldm/hunyuan3dv2_1/hunyuandit.py
Normal file
@ -0,0 +1,659 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
class GELU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
if gate.device.type == "mps":
|
||||
return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
|
||||
|
||||
return F.gelu(gate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.gelu(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim: int, dim_out = None, mult: int = 4,
|
||||
dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
self.net.append(act_fn)
|
||||
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class AddAuxLoss(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, loss):
|
||||
# do nothing in forward (no computation)
|
||||
ctx.requires_aux_loss = loss.requires_grad
|
||||
ctx.dtype = loss.dtype
|
||||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# add the aux loss gradients
|
||||
grad_loss = None
|
||||
# put the aux grad the same as the main grad loss
|
||||
# aux grad contributes equally
|
||||
if ctx.requires_aux_loss:
|
||||
grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
|
||||
|
||||
return grad_output, grad_loss
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
|
||||
|
||||
super().__init__()
|
||||
self.top_k = num_experts_per_tok
|
||||
self.n_routed_experts = num_experts
|
||||
|
||||
self.alpha = aux_loss_alpha
|
||||
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# flatten hidden states
|
||||
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
|
||||
|
||||
# get logits and pass it to softmax
|
||||
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
|
||||
scores = logits.softmax(dim = -1)
|
||||
|
||||
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
|
||||
|
||||
if self.training and self.alpha > 0.0:
|
||||
scores_for_aux = scores
|
||||
|
||||
# used bincount instead of one hot encoding
|
||||
counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
|
||||
ce = counts / topk_idx.numel() # normalized expert usage
|
||||
|
||||
# mean expert score
|
||||
Pi = scores_for_aux.mean(0)
|
||||
|
||||
# expert balance loss
|
||||
aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
|
||||
else:
|
||||
aux_loss = None
|
||||
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
class MoEBlock(nn.Module):
|
||||
def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
|
||||
ff_inner_dim: int = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.moe_top_k = moe_top_k
|
||||
self.num_experts = num_experts
|
||||
|
||||
self.experts = nn.ModuleList([
|
||||
FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
for _ in range(num_experts)
|
||||
])
|
||||
|
||||
self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
|
||||
self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states) -> torch.Tensor:
|
||||
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
|
||||
if self.training:
|
||||
|
||||
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
|
||||
y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
|
||||
|
||||
for i, expert in enumerate(self.experts):
|
||||
tmp = expert(hidden_states[flat_topk_idx == i])
|
||||
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
|
||||
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
|
||||
y = y.view(*orig_shape)
|
||||
|
||||
y = AddAuxLoss.apply(y, aux_loss)
|
||||
else:
|
||||
y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
|
||||
y = y + self.shared_experts(identity)
|
||||
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
|
||||
# no need for .numpy().cpu() here
|
||||
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
|
||||
token_idxs = idxs // self.moe_top_k
|
||||
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens)
|
||||
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
|
||||
# use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_
|
||||
# + avoid dtype conversion
|
||||
expert_cache.index_add_(0, exp_token_idx, expert_out)
|
||||
|
||||
return expert_cache
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
|
||||
scale: float = 1.0, max_period: int = 10000):
|
||||
super().__init__()
|
||||
|
||||
self.num_channels = num_channels
|
||||
half_dim = num_channels // 2
|
||||
|
||||
# precompute the “inv_freq” vector once
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
half_dim, dtype=torch.float32
|
||||
) / (half_dim - downscale_freq_shift)
|
||||
|
||||
inv_freq = torch.exp(exponent)
|
||||
|
||||
# pad
|
||||
if num_channels % 2 == 1:
|
||||
# we’ll pad a zero at the end of the cos-half
|
||||
inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
|
||||
|
||||
# register to buffer so it moves with the device
|
||||
self.register_buffer("inv_freq", inv_freq, persistent = False)
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps: torch.Tensor):
|
||||
|
||||
x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
|
||||
|
||||
|
||||
# fused CUDA kernels for sin and cos
|
||||
sin_emb = x.sin()
|
||||
cos_emb = x.cos()
|
||||
|
||||
emb = torch.cat([sin_emb, cos_emb], dim = 1)
|
||||
|
||||
# scale factor
|
||||
if self.scale != 1.0:
|
||||
emb = emb * self.scale
|
||||
|
||||
# If we padded inv_freq for odd, emb is already wide enough; otherwise:
|
||||
if emb.shape[1] > self.num_channels:
|
||||
emb = emb[:, :self.num_channels]
|
||||
|
||||
return emb
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
|
||||
nn.GELU(),
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
|
||||
|
||||
self.time_embed = Timesteps(hidden_size)
|
||||
|
||||
def forward(self, timesteps, condition):
|
||||
|
||||
timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
|
||||
|
||||
if condition is not None:
|
||||
cond_embed = self.cond_proj(condition)
|
||||
timestep_embed = timestep_embed + cond_embed
|
||||
|
||||
time_conditioned = self.mlp(timestep_embed)
|
||||
|
||||
# for broadcasting with image tokens
|
||||
return time_conditioned.unsqueeze(1)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, *, width: int, operations = None, device = None, dtype = None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
|
||||
self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
|
||||
self.gelu = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.gelu(self.fc1(x)))
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
qdim,
|
||||
kdim,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.qdim = qdim
|
||||
self.kdim = kdim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.qdim // num_heads
|
||||
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, y):
|
||||
|
||||
b, s1, _ = x.shape
|
||||
_, s2, _ = y.shape
|
||||
|
||||
y = y.to(next(self.to_k.parameters()).dtype)
|
||||
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(y)
|
||||
v = self.to_v(y)
|
||||
|
||||
kv = torch.cat((k, v), dim=-1)
|
||||
split_size = kv.shape[-1] // self.num_heads // 2
|
||||
|
||||
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
||||
k, v = torch.split(kv, split_size, dim=-1)
|
||||
|
||||
q = q.view(b, s1, self.num_heads, self.head_dim)
|
||||
k = k.view(b, s2, self.num_heads, self.head_dim)
|
||||
v = v.reshape(b, s2, self.num_heads * self.head_dim)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
x = optimized_attention(
|
||||
q.reshape(b, s1, self.num_heads * self.head_dim),
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
|
||||
return out
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
qkv_bias = True,
|
||||
qk_norm = False,
|
||||
norm_layer = nn.LayerNorm,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None,
|
||||
dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
if norm_layer == nn.LayerNorm:
|
||||
norm_layer = operations.LayerNorm
|
||||
else:
|
||||
norm_layer = operations.RMSNorm
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
|
||||
self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, _ = x.shape
|
||||
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
qkv_combined = torch.cat((query, key, value), dim=-1)
|
||||
split_size = qkv_combined.shape[-1] // self.num_heads // 3
|
||||
|
||||
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
query = query.reshape(B, N, self.num_heads, self.head_dim)
|
||||
key = key.reshape(B, N, self.num_heads, self.head_dim)
|
||||
value = value.reshape(B, N, self.num_heads * self.head_dim)
|
||||
|
||||
query = self.q_norm(query)
|
||||
key = self.k_norm(key)
|
||||
|
||||
x = optimized_attention(
|
||||
query.reshape(B, N, self.num_heads * self.head_dim),
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
c_emb_size,
|
||||
num_heads,
|
||||
text_states_dim=1024,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm_layer=True,
|
||||
qkv_bias=True,
|
||||
skip_connection=True,
|
||||
timested_modulate=False,
|
||||
use_moe: bool = False,
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
operations = None,
|
||||
device = None, dtype = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# eps can't be 1e-6 in fp16 mode because of numerical stability issues
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
|
||||
norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
self.timested_modulate = timested_modulate
|
||||
if self.timested_modulate:
|
||||
self.default_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
|
||||
)
|
||||
|
||||
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
|
||||
self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
|
||||
if skip_connection:
|
||||
self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
|
||||
else:
|
||||
self.skip_linear = None
|
||||
|
||||
self.use_moe = use_moe
|
||||
|
||||
if self.use_moe:
|
||||
self.moe = MoEBlock(
|
||||
hidden_size,
|
||||
num_experts = num_experts,
|
||||
moe_top_k = moe_top_k,
|
||||
dropout = 0.0,
|
||||
ff_inner_dim = int(hidden_size * 4.0),
|
||||
device = device, dtype = dtype,
|
||||
operations = operations
|
||||
)
|
||||
else:
|
||||
self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
|
||||
|
||||
if self.skip_linear is not None:
|
||||
combined = torch.cat([skip_tensor, hidden_states], dim=-1)
|
||||
hidden_states = self.skip_linear(combined)
|
||||
hidden_states = self.skip_norm(hidden_states)
|
||||
|
||||
# self attention
|
||||
if self.timested_modulate:
|
||||
modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
|
||||
hidden_states = hidden_states + modulation_shift
|
||||
|
||||
self_attn_out = self.attn1(self.norm1(hidden_states))
|
||||
hidden_states = hidden_states + self_attn_out
|
||||
|
||||
# cross attention
|
||||
hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
|
||||
|
||||
# MLP Layer
|
||||
mlp_input = self.norm3(hidden_states)
|
||||
|
||||
if self.use_moe:
|
||||
hidden_states = hidden_states + self.moe(mlp_input)
|
||||
else:
|
||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
|
||||
def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
|
||||
super().__init__()
|
||||
|
||||
if use_fp16:
|
||||
eps = 1.0 / 65504
|
||||
else:
|
||||
eps = 1e-6
|
||||
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
|
||||
self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm_final(x)
|
||||
x = x[:, 1:]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class HunYuanDiTPlain(nn.Module):
|
||||
|
||||
# init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 64,
|
||||
hidden_size: int = 2048,
|
||||
context_dim: int = 1024,
|
||||
depth: int = 21,
|
||||
num_heads: int = 16,
|
||||
qk_norm: bool = True,
|
||||
qkv_bias: bool = False,
|
||||
num_moe_layers: int = 6,
|
||||
guidance_cond_proj_dim = 2048,
|
||||
norm_type = 'layer',
|
||||
num_experts: int = 8,
|
||||
moe_top_k: int = 2,
|
||||
use_fp16: bool = False,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
|
||||
qk_norm = operations.RMSNorm
|
||||
|
||||
self.context_dim = context_dim
|
||||
self.guidance_cond_proj_dim = guidance_cond_proj_dim
|
||||
|
||||
self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
|
||||
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
c_emb_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
text_states_dim=context_dim,
|
||||
qk_norm=qk_norm,
|
||||
norm_layer = norm,
|
||||
qk_norm_layer = qk_norm,
|
||||
skip_connection=layer > depth // 2,
|
||||
qkv_bias=qkv_bias,
|
||||
use_moe=True if depth - layer <= num_moe_layers else False,
|
||||
num_experts=num_experts,
|
||||
moe_top_k=moe_top_k,
|
||||
use_fp16 = use_fp16,
|
||||
device = device, dtype = dtype, operations = operations)
|
||||
for layer in range(depth)
|
||||
])
|
||||
|
||||
self.depth = depth
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
|
||||
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
|
||||
time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
|
||||
|
||||
x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
|
||||
x_embedded = self.x_embedder(x)
|
||||
|
||||
combined = torch.cat([time_embedded, x_embedded], dim=1)
|
||||
|
||||
def block_wrap(args):
|
||||
return block(
|
||||
args["x"],
|
||||
args["t"],
|
||||
args["cond"],
|
||||
skip_tensor=args.get("skip"),)
|
||||
|
||||
skip_stack = []
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for idx, block in enumerate(self.blocks):
|
||||
if idx <= self.depth // 2:
|
||||
skip_input = None
|
||||
else:
|
||||
skip_input = skip_stack.pop()
|
||||
|
||||
if ("block", idx) in blocks_replace:
|
||||
|
||||
combined = blocks_replace[("block", idx)](
|
||||
{
|
||||
"x": combined,
|
||||
"t": time_embedded,
|
||||
"cond": main_condition,
|
||||
"skip": skip_input,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
else:
|
||||
combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
|
||||
|
||||
if idx < self.depth // 2:
|
||||
skip_stack.append(combined)
|
||||
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
@ -1,11 +1,11 @@
|
||||
#Based on Flux code because of weird hunyuan video code license.
|
||||
|
||||
import torch
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
@ -39,6 +39,11 @@ class HunyuanVideoParams:
|
||||
patch_size: list
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
meanflow_sum: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@ -77,13 +82,13 @@ class TokenRefinerBlock(nn.Module):
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn.qkv(norm_x)
|
||||
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
@ -114,14 +119,14 @@ class IndividualTokenRefiner(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
m = None
|
||||
if mask is not None:
|
||||
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||
m = m + m.transpose(2, 3)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, m)
|
||||
x = block(x, c, m, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
@ -149,17 +154,45 @@ class TokenRefiner(nn.Module):
|
||||
x,
|
||||
timesteps,
|
||||
mask,
|
||||
transformer_options={},
|
||||
):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
if x.dtype == torch.float16:
|
||||
c = x.float().sum(dim=1) / x.shape[1]
|
||||
else:
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
class ByT5Mapper(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
|
||||
self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
|
||||
self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
|
||||
self.use_res = use_res
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res:
|
||||
res = x
|
||||
x = self.layernorm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
x2 = self.act_fn(x)
|
||||
x2 = self.fc3(x2)
|
||||
if self.use_res:
|
||||
x2 = x2 + res
|
||||
return x2
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@ -168,11 +201,15 @@ class HunyuanVideo(nn.Module):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
self.use_cond_type_embedding = params.use_cond_type_embedding
|
||||
self.vision_in_dim = params.vision_in_dim
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@ -184,9 +221,13 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
@ -214,9 +255,38 @@ class HunyuanVideo(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
if params.byt5:
|
||||
self.byt5_in = ByT5Mapper(
|
||||
in_dim=1472,
|
||||
out_dim=2048,
|
||||
hidden_dim=2048,
|
||||
out_dim1=self.hidden_size,
|
||||
use_res=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
self.byt5_in = None
|
||||
|
||||
if params.meanflow:
|
||||
self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.time_r_in = None
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# HunyuanVideo 1.5 specific modules
|
||||
if self.vision_in_dim is not None:
|
||||
from comfy.ldm.wan.model import MLPProj
|
||||
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
|
||||
else:
|
||||
self.vision_in = None
|
||||
if self.use_cond_type_embedding:
|
||||
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
|
||||
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
|
||||
else:
|
||||
self.cond_type_embedding = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
@ -225,10 +295,13 @@ class HunyuanVideo(nn.Module):
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
clip_fea=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
disable_time_r=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@ -239,6 +312,14 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
if (self.time_r_in is not None) and (not disable_time_r):
|
||||
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
|
||||
if len(w) > 0:
|
||||
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
||||
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
||||
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
ref_latent = self.img_in(ref_latent)
|
||||
@ -249,13 +330,17 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
if guiding_frame_index is not None:
|
||||
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
if self.vector_in is not None:
|
||||
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
||||
else:
|
||||
vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
|
||||
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
||||
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
||||
modulation_dims_txt = [(0, None, 1)]
|
||||
else:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if self.vector_in is not None:
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
modulation_dims = None
|
||||
modulation_dims_txt = None
|
||||
|
||||
@ -266,7 +351,32 @@ class HunyuanVideo(nn.Module):
|
||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.cond_type_embedding is not None:
|
||||
self.cond_type_embedding.to(txt.device)
|
||||
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
|
||||
txt = txt + cond_emb.to(txt.dtype)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
|
||||
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
|
||||
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
|
||||
else:
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
if clip_fea is not None:
|
||||
txt_vision_states = self.vision_in(clip_fea)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
|
||||
txt_vision_states = txt_vision_states + cond_emb
|
||||
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
@ -280,18 +390,21 @@ class HunyuanVideo(nn.Module):
|
||||
attn_mask = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -302,17 +415,20 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@ -327,12 +443,16 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
shape = initial_shape[-len(self.patch_size):]
|
||||
for i in range(len(shape)):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
if img.ndim == 8:
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
||||
else:
|
||||
img = img.permute(0, 3, 1, 4, 2, 5)
|
||||
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
|
||||
return img
|
||||
|
||||
def img_ids(self, x):
|
||||
@ -347,9 +467,30 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
def img_ids_2d(self, x):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
|
||||
w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
|
||||
img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
121
comfy/ldm/hunyuan_video/upsampler.py
Normal file
121
comfy/ldm/hunyuan_video/upsampler.py
Normal file
@ -0,0 +1,121 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import model_management, model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
class SRModel3DV2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
hidden_channels: int = 64,
|
||||
num_blocks: int = 6,
|
||||
global_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
|
||||
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
|
||||
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
|
||||
self.global_residual = bool(global_residual)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
y = self.in_conv(x)
|
||||
for blk in self.blocks:
|
||||
y = blk(y)
|
||||
y = self.out_conv(y)
|
||||
if self.global_residual and (y.shape == residual.shape):
|
||||
y = y + residual
|
||||
return y
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
z_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: tuple[int, ...],
|
||||
num_res_blocks: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.block_out_channels = block_out_channels
|
||||
self.z_channels = z_channels
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_shortcut=False,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Args:
|
||||
z: (B, C, T, H, W)
|
||||
target_shape: (H, W)
|
||||
"""
|
||||
# z to block_in
|
||||
repeats = self.block_out_channels[0] // (self.z_channels)
|
||||
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
||||
|
||||
# upsampling
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x)))
|
||||
return out
|
||||
|
||||
UPSAMPLERS = {
|
||||
"720p": SRModel3DV2,
|
||||
"1080p": Upsampler,
|
||||
}
|
||||
|
||||
class HunyuanVideo15SRModel():
|
||||
def __init__(self, model_type, config):
|
||||
self.load_device = model_management.vae_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.dtype = model_management.vae_dtype(self.load_device)
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def resample_latent(self, latent):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
return self.model(latent.to(self.load_device))
|
||||
136
comfy/ldm/hunyuan_video/vae.py
Normal file
136
comfy/ldm/hunyuan_video/vae.py
Normal file
@ -0,0 +1,136 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class PixelShuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
|
||||
self.ratio = (in_dim << 2) // out_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h >> 1, w >> 1
|
||||
y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
|
||||
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
|
||||
return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
|
||||
|
||||
|
||||
class PixelUnshuffle2D(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
|
||||
self.scale = (out_dim << 2) // in_dim
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
h2, w2 = h << 1, w << 1
|
||||
y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
|
||||
return y + r
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, downsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
b, c, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, h, w).mean(2)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, upsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=ops.Conv2d)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
|
||||
self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, z):
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x)))
|
||||
313
comfy/ldm/hunyuan_video/vae_refiner.py
Normal file
313
comfy/ldm/hunyuan_video/vae_refiner.py
Normal file
@ -0,0 +1,313 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
import comfy.model_management
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
shape = (dim, 1, 1, 1)
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.empty(shape))
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds, refiner_vae, op):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
|
||||
assert oc % fct == 0
|
||||
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
|
||||
self.refiner_vae = refiner_vae
|
||||
|
||||
self.tds = tds
|
||||
self.gs = fct * ic // oc
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tds else 1
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
|
||||
hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
|
||||
hf = torch.cat([hf, hf], dim=1)
|
||||
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
|
||||
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xf.shape
|
||||
xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
if h.shape[2] == 0:
|
||||
return hf + xf
|
||||
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tus, refiner_vae, op):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
|
||||
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
|
||||
self.refiner_vae = refiner_vae
|
||||
|
||||
self.tus = tus
|
||||
self.rp = fct * oc // ic
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tus else 1
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
nc = c // (2 * 2)
|
||||
hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
|
||||
hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
hf = hf[:, : hf.shape[1] // 2]
|
||||
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
|
||||
b, c, f, ht, wd = xf.shape
|
||||
nc = c // (2 * 2)
|
||||
xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
|
||||
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
x = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = x.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.ffactor_temporal = ffactor_temporal
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
norm_op = Normalize
|
||||
|
||||
self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
|
||||
|
||||
def forward(self, x):
|
||||
if not self.refiner_vae and x.shape[2] == 1:
|
||||
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
|
||||
|
||||
if self.refiner_vae:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.ffactor_temporal:
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
x1 = [ x1 ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
|
||||
del out
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
|
||||
|
||||
out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
|
||||
|
||||
if self.refiner_vae:
|
||||
out = self.regul(out)[0]
|
||||
|
||||
return out
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
norm_op = Normalize
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
depth_temporal = (ffactor_temporal >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
if self.refiner_vae:
|
||||
x = torch.split(x, 2, dim=2)
|
||||
else:
|
||||
x = [ x ]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
x1 = [ F.silu(self.norm_out(x1)) ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
del x
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
if not self.refiner_vae:
|
||||
if z.shape[-3] == 1:
|
||||
out = out[:, :, -1:]
|
||||
|
||||
return out
|
||||
|
||||
413
comfy/ldm/kandinsky5/model.py
Normal file
413
comfy/ldm/kandinsky5/model.py
Normal file
@ -0,0 +1,413 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
def attention(q, k, v, heads, transformer_options={}):
|
||||
return optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
heads=heads,
|
||||
skip_reshape=True,
|
||||
transformer_options=transformer_options
|
||||
)
|
||||
|
||||
def apply_scale_shift_norm(norm, x, scale, shift):
|
||||
return torch.addcmul(shift, norm(x), scale + 1.0)
|
||||
|
||||
def apply_gate_sum(x, out, gate):
|
||||
return torch.addcmul(x, gate, out)
|
||||
|
||||
def get_shift_scale_gate(params):
|
||||
shift, scale, gate = torch.chunk(params, 3, dim=-1)
|
||||
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
||||
|
||||
def get_freqs(dim, max_period=10000.0):
|
||||
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
|
||||
|
||||
|
||||
class TimeEmbeddings(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
|
||||
super().__init__()
|
||||
assert model_dim % 2 == 0
|
||||
self.model_dim = model_dim
|
||||
self.max_period = max_period
|
||||
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
|
||||
|
||||
|
||||
class TextEmbeddings(nn.Module):
|
||||
def __init__(self, text_dim, model_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, text_embed):
|
||||
text_embed = self.in_layer(text_embed)
|
||||
return self.norm(text_embed).type_as(text_embed)
|
||||
|
||||
|
||||
class VisualEmbeddings(nn.Module):
|
||||
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.movedim(1, -1) # B C T H W -> B T H W C
|
||||
B, T, H, W, dim = x.shape
|
||||
pt, ph, pw = self.patch_size
|
||||
|
||||
x = x.view(
|
||||
B,
|
||||
T // pt, pt,
|
||||
H // ph, ph,
|
||||
W // pw, pw,
|
||||
dim,
|
||||
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
|
||||
|
||||
return self.in_layer(x)
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
|
||||
super().__init__()
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, num_channels, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
assert num_channels % head_dim == 0
|
||||
self.num_heads = num_channels // head_dim
|
||||
self.head_dim = head_dim
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 2
|
||||
|
||||
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
|
||||
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return apply_rope1(norm_fn(result), freqs)
|
||||
|
||||
def _forward(self, x, freqs, transformer_options={}):
|
||||
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def _forward_chunked(self, x, freqs, transformer_options={}):
|
||||
def process_chunks(proj_fn, norm_fn):
|
||||
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||
chunks = []
|
||||
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
|
||||
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
q = process_chunks(self.to_query, self.query_norm)
|
||||
k = process_chunks(self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
||||
else:
|
||||
return self._forward(x, freqs, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class CrossAttention(SelfAttention):
|
||||
def get_qkv(self, x, context):
|
||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
return q, k, v
|
||||
|
||||
def forward(self, x, context, transformer_options={}):
|
||||
q, k, v = self.get_qkv(x, context)
|
||||
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, ff_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.GELU()
|
||||
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 4
|
||||
|
||||
def _forward(self, x):
|
||||
return self.out_layer(self.activation(self.in_layer(x)))
|
||||
|
||||
def _forward_chunked(self, x):
|
||||
chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
output_chunks = []
|
||||
for chunk in chunks:
|
||||
output_chunks.append(self._forward(chunk))
|
||||
return torch.cat(output_chunks, dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class OutLayer(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, visual_embed, time_embed):
|
||||
B, T, H, W, _ = visual_embed.shape
|
||||
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
|
||||
scale = scale[:, None, None, None, :]
|
||||
shift = shift[:, None, None, None, :]
|
||||
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
|
||||
x = self.out_layer(visual_embed)
|
||||
|
||||
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
|
||||
x = x.view(
|
||||
B, T, H, W,
|
||||
out_dim,
|
||||
self.patch_size[0], self.patch_size[1], self.patch_size[2]
|
||||
)
|
||||
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
|
||||
|
||||
|
||||
class TransformerEncoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
|
||||
out = self.self_attention(out, freqs, transformer_options=transformer_options)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
|
||||
out = self.feed_forward(out)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
||||
# self attention
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# cross attention
|
||||
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# feed forward
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
|
||||
visual_out = self.feed_forward(visual_out)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
return visual_embed
|
||||
|
||||
|
||||
class Kandinsky5(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
|
||||
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
|
||||
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
|
||||
dtype=None, device=None, operations=None, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
head_dim = sum(axes_dims)
|
||||
self.rope_scale_factor = rope_scale_factor
|
||||
self.in_visual_dim = in_visual_dim
|
||||
self.model_dim = model_dim
|
||||
self.patch_size = patch_size
|
||||
self.visual_embed_dim = visual_embed_dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
|
||||
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
|
||||
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
|
||||
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.text_transformer_blocks = nn.ModuleList(
|
||||
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
|
||||
)
|
||||
|
||||
self.visual_transformer_blocks = nn.ModuleList(
|
||||
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
|
||||
)
|
||||
|
||||
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
|
||||
|
||||
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
|
||||
steps = seq_len if steps is None else steps
|
||||
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
|
||||
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
|
||||
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
else:
|
||||
rope_scale_factor = self.rope_scale_factor
|
||||
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
|
||||
if h * w >= 14080:
|
||||
rope_scale_factor = (1.0, 3.16, 3.16)
|
||||
|
||||
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
|
||||
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
|
||||
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
context = self.text_embeddings(context)
|
||||
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
|
||||
|
||||
for block in self.text_transformer_blocks:
|
||||
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = self.visual_embeddings(x)
|
||||
visual_shape = visual_embed.shape[:-1]
|
||||
visual_embed = visual_embed.flatten(1, -2)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.visual_transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
||||
else:
|
||||
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||
return self.out_layer(visual_embed, time_embed)
|
||||
|
||||
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
original_dims = x.ndim
|
||||
if original_dims == 4:
|
||||
x = x.unsqueeze(2)
|
||||
bs, c, t_len, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if time_dim_replace is not None:
|
||||
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
|
||||
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
|
||||
|
||||
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
||||
if original_dims == 4:
|
||||
out = out.squeeze(2)
|
||||
return out
|
||||
|
||||
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
|
||||
@ -1,13 +1,13 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
@ -237,20 +237,6 @@ class FeedForward(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@ -261,8 +247,8 @@ class CrossAttention(nn.Module):
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
@ -270,7 +256,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None):
|
||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
@ -280,13 +266,13 @@ class CrossAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@ -302,15 +288,20 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask)
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
x.addcmul_(self.ff(y), gate_mlp)
|
||||
|
||||
return x
|
||||
|
||||
@ -326,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
dtype = torch.float32
|
||||
device = indices_grid.device
|
||||
|
||||
# Get fractional positions and compute frequency indices
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
# Compute frequencies and apply cos/sin
|
||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
# Pad if dim is not divisible by 6
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
padding_size = dim % 6
|
||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
||||
|
||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
|
||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
||||
freqs_cis = torch.stack([
|
||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
||||
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
@ -420,6 +405,13 @@ class LTXVModel(torch.nn.Module):
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
@ -471,10 +463,10 @@ class LTXVModel(torch.nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@ -482,7 +474,8 @@ class LTXVModel(torch.nn.Module):
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
timestep=timestep,
|
||||
pe=pe
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
@ -492,7 +485,7 @@ class LTXVModel(torch.nn.Module):
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = torch.addcmul(x, x, scale).add_(shift)
|
||||
x = self.proj_out(x)
|
||||
|
||||
x = self.patchifier.unpatchify(
|
||||
|
||||
@ -973,7 +973,7 @@ class VideoVAE(nn.Module):
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=self.timestep_conditioning,
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "reflect"),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
160
comfy/ldm/lumina/controlnet.py
Normal file
160
comfy/ldm/lumina/controlnet.py
Normal file
@ -0,0 +1,160 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .model import JointTransformerBlock
|
||||
|
||||
class ZImageControlTransformerBlock(JointTransformerBlock):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: float,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
block_id=0,
|
||||
operation_settings=None,
|
||||
):
|
||||
super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
return c_skip, c
|
||||
|
||||
class ZImage_Control(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 3840,
|
||||
n_heads: int = 30,
|
||||
n_kv_heads: int = 30,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: float = (8.0 / 3.0),
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = True,
|
||||
n_control_layers=6,
|
||||
control_in_dim=16,
|
||||
additional_in_dim=0,
|
||||
broken=False,
|
||||
refiner_control=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.broken = broken
|
||||
self.additional_in_dim = additional_in_dim
|
||||
self.control_in_dim = control_in_dim
|
||||
n_refiner_layers = 2
|
||||
self.n_control_layers = n_control_layers
|
||||
self.control_layers = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
i,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=i,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for i in range(self.n_control_layers)
|
||||
]
|
||||
)
|
||||
|
||||
all_x_embedder = {}
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
self.refiner_control = refiner_control
|
||||
|
||||
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
if self.refiner_control:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=layer_id,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
pH = pW = patch_size
|
||||
B, C, H, W = control_context.shape
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
x_attn_mask = None
|
||||
if not self.refiner_control:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||
|
||||
return control_context
|
||||
|
||||
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
if self.refiner_control:
|
||||
if self.broken:
|
||||
if layer_id == 0:
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if layer_id > 0:
|
||||
out = None
|
||||
for i in range(1, len(self.control_layers)):
|
||||
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if out is None:
|
||||
out = o
|
||||
|
||||
return (out, control_context)
|
||||
else:
|
||||
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
else:
|
||||
return (None, control_context)
|
||||
|
||||
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
@ -11,6 +11,8 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
@ -20,6 +22,10 @@ def modulate(x, scale):
|
||||
# Core NextDiT Model #
|
||||
#############################################################################
|
||||
|
||||
def clamp_fp16(x):
|
||||
if x.dtype == torch.float16:
|
||||
return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
class JointAttention(nn.Module):
|
||||
"""Multi-head attention module."""
|
||||
@ -30,6 +36,7 @@ class JointAttention(nn.Module):
|
||||
n_heads: int,
|
||||
n_kv_heads: Optional[int],
|
||||
qk_norm: bool,
|
||||
out_bias: bool = False,
|
||||
operation_settings={},
|
||||
):
|
||||
"""
|
||||
@ -58,7 +65,7 @@ class JointAttention(nn.Module):
|
||||
self.out = operation_settings.get("operations").Linear(
|
||||
n_heads * self.head_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
bias=out_bias,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
@ -69,40 +76,12 @@ class JointAttention(nn.Module):
|
||||
else:
|
||||
self.q_norm = self.k_norm = nn.Identity()
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_emb(
|
||||
x_in: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency
|
||||
tensor.
|
||||
|
||||
This function applies rotary embeddings to the given query 'xq' and
|
||||
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
|
||||
input tensors are reshaped as complex numbers, and the frequency tensor
|
||||
is reshaped for broadcasting compatibility. The resulting tensors
|
||||
contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
|
||||
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
|
||||
exponentials.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
|
||||
and key tensor with rotary embeddings.
|
||||
"""
|
||||
|
||||
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x_in.shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
transformer_options={},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@ -132,14 +111,13 @@ class JointAttention(nn.Module):
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
||||
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis)
|
||||
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
|
||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
return self.out(output)
|
||||
|
||||
@ -195,7 +173,7 @@ class FeedForward(nn.Module):
|
||||
|
||||
# @torch.compile
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
return clamp_fp16(F.silu(x1) * x3)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
@ -213,6 +191,8 @@ class JointTransformerBlock(nn.Module):
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
z_image_modulation=False,
|
||||
attn_out_bias=False,
|
||||
operation_settings={},
|
||||
) -> None:
|
||||
"""
|
||||
@ -233,10 +213,10 @@ class JointTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
hidden_dim=dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier,
|
||||
operation_settings=operation_settings,
|
||||
@ -250,16 +230,27 @@ class JointTransformerBlock(nn.Module):
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
if z_image_modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 256),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -267,6 +258,7 @@ class JointTransformerBlock(nn.Module):
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor]=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
Perform a forward pass through the TransformerBlock.
|
||||
@ -285,25 +277,27 @@ class JointTransformerBlock(nn.Module):
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
self.attention(
|
||||
clamp_fp16(self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
)
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
clamp_fp16(self.feed_forward(
|
||||
modulate(self.ffn_norm1(x), scale_mlp),
|
||||
)
|
||||
))
|
||||
)
|
||||
else:
|
||||
assert adaln_input is None
|
||||
x = x + self.attention_norm2(
|
||||
self.attention(
|
||||
clamp_fp16(self.attention(
|
||||
self.attention_norm1(x),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
)
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
)
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
@ -318,7 +312,7 @@ class FinalLayer(nn.Module):
|
||||
The final layer of NextDiT.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
|
||||
def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
|
||||
super().__init__()
|
||||
self.norm_final = operation_settings.get("operations").LayerNorm(
|
||||
hidden_size,
|
||||
@ -335,10 +329,15 @@ class FinalLayer(nn.Module):
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
if z_image_modulation:
|
||||
min_mod = 256
|
||||
else:
|
||||
min_mod = 1024
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(hidden_size, 1024),
|
||||
min(hidden_size, min_mod),
|
||||
hidden_size,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
@ -368,12 +367,17 @@ class NextDiT(nn.Module):
|
||||
n_heads: int = 32,
|
||||
n_kv_heads: Optional[int] = None,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
ffn_dim_multiplier: float = 4.0,
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = False,
|
||||
cap_feat_dim: int = 5120,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (1, 512, 512),
|
||||
rope_theta=10000.0,
|
||||
z_image_modulation=False,
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
clip_text_dim=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@ -385,6 +389,8 @@ class NextDiT(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.time_scale = time_scale
|
||||
self.pad_tokens_multiple = pad_tokens_multiple
|
||||
|
||||
self.x_embedder = operation_settings.get("operations").Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
@ -406,6 +412,7 @@ class NextDiT(nn.Module):
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=z_image_modulation,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
@ -429,7 +436,7 @@ class NextDiT(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
@ -441,6 +448,31 @@ class NextDiT(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
self.clip_text_pooled_proj = None
|
||||
|
||||
if clip_text_dim is not None:
|
||||
self.clip_text_dim = clip_text_dim
|
||||
self.clip_text_pooled_proj = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
clip_text_dim,
|
||||
clip_text_dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.time_text_embed = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024) + clip_text_dim,
|
||||
min(dim, 1024),
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
@ -452,18 +484,24 @@ class NextDiT(nn.Module):
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
z_image_modulation=z_image_modulation,
|
||||
attn_out_bias=False,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
|
||||
assert (dim // n_heads) == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
@ -493,105 +531,79 @@ class NextDiT(nn.Module):
|
||||
return imgs
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
orig_x = x
|
||||
|
||||
if cap_mask is not None:
|
||||
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
|
||||
else:
|
||||
l_effective_cap_len = [num_tokens] * bsz
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
|
||||
if cap_mask is not None and not torch.is_floating_point(cap_mask):
|
||||
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in x]
|
||||
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
max_seq_len = max(
|
||||
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
|
||||
)
|
||||
max_cap_len = max(l_effective_cap_len)
|
||||
max_img_len = max(l_effective_img_len)
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
assert H_tokens * W_tokens == img_len
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
||||
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
||||
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
|
||||
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
|
||||
# build freqs_cis for cap and image individually
|
||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||
# cap_freqs_cis_shape[1] = max_cap_len
|
||||
cap_freqs_cis_shape[1] = cap_feats.shape[1]
|
||||
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
|
||||
# refine image
|
||||
flat_x = []
|
||||
for i in range(bsz):
|
||||
img = x[i]
|
||||
C, H, W = img.size()
|
||||
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
|
||||
flat_x.append(img)
|
||||
x = flat_x
|
||||
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
|
||||
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
|
||||
for i in range(bsz):
|
||||
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
|
||||
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
|
||||
|
||||
padded_img_embed = self.x_embedder(padded_img_embed)
|
||||
padded_img_mask = padded_img_mask.unsqueeze(1)
|
||||
for layer in self.noise_refiner:
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
|
||||
|
||||
if cap_mask is not None:
|
||||
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
|
||||
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
|
||||
else:
|
||||
mask = None
|
||||
|
||||
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
|
||||
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
|
||||
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
|
||||
padded_img_mask = None
|
||||
x_input = x
|
||||
for i, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
if "noise_refiner" in patches:
|
||||
for p in patches["noise_refiner"]:
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
if "img" in out:
|
||||
x = out["img"]
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
mask = None
|
||||
img_sizes = [(H, W)] * bsz
|
||||
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
@ -603,20 +615,41 @@ class NextDiT(nn.Module):
|
||||
y: (N,) tensor of text tokens/features
|
||||
"""
|
||||
|
||||
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
|
||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
|
||||
freqs_cis = freqs_cis.to(x.device)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, adaln_input)
|
||||
transformer_options["total_blocks"] = len(self.layers)
|
||||
transformer_options["block_type"] = "double"
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
transformer_options["block_index"] = i
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
if "img" in out:
|
||||
img[:, cap_size[0]:] = out["img"]
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
||||
img = self.final_layer(img, adaln_input)
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
|
||||
return -x
|
||||
return -img
|
||||
|
||||
|
||||
0
comfy/ldm/mmaudio/vae/__init__.py
Normal file
0
comfy/ldm/mmaudio/vae/__init__.py
Normal file
120
comfy/ldm/mmaudio/vae/activations.py
Normal file
120
comfy/ldm/mmaudio/vae/activations.py
Normal file
@ -0,0 +1,120 @@
|
||||
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch.nn import Parameter
|
||||
import comfy.model_management
|
||||
|
||||
class Snake(nn.Module):
|
||||
'''
|
||||
Implementation of a sine-based periodic activation function
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter
|
||||
References:
|
||||
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snake(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha: trainable parameter
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(Snake, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
else:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
Snake ∶= x + 1/a * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
'''
|
||||
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
References:
|
||||
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snakebeta(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(SnakeBeta, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
self.beta = Parameter(torch.empty(in_features))
|
||||
else:
|
||||
self.alpha = Parameter(torch.empty(in_features))
|
||||
self.beta = Parameter(torch.empty(in_features))
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
157
comfy/ldm/mmaudio/vae/alias_free_torch.py
Normal file
157
comfy/ldm/mmaudio/vae/alias_free_torch.py
Normal file
@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import comfy.model_management
|
||||
|
||||
if 'sinc' in dir(torch):
|
||||
sinc = torch.sinc
|
||||
else:
|
||||
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/core.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def sinc(x: torch.Tensor):
|
||||
"""
|
||||
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||
"""
|
||||
return torch.where(x == 0,
|
||||
torch.tensor(1., device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x)
|
||||
|
||||
|
||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
even = (kernel_size % 2 == 0)
|
||||
half_size = kernel_size // 2
|
||||
|
||||
#For kaiser window
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.:
|
||||
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
||||
else:
|
||||
beta = 0.
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
|
||||
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||
if even:
|
||||
time = (torch.arange(-half_size, half_size) + 0.5)
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
||||
# of the constant component in the input signal.
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
|
||||
return filter
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride: int = 1,
|
||||
padding: bool = True,
|
||||
padding_mode: str = 'replicate',
|
||||
kernel_size: int = 12):
|
||||
# kernel_size should be even number for stylegan3 setup,
|
||||
# in this implementation, odd number is also possible.
|
||||
super().__init__()
|
||||
if cutoff < -0.:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = (kernel_size % 2 == 0)
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
#input [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||
mode=self.padding_mode)
|
||||
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
|
||||
stride=self.stride, groups=C)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.stride = ratio
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# x: [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left:-self.pad_right]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
xx = self.lowpass(x)
|
||||
|
||||
return xx
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(self,
|
||||
activation,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12):
|
||||
super().__init__()
|
||||
self.up_ratio = up_ratio
|
||||
self.down_ratio = down_ratio
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
# x: [B,C,T]
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
|
||||
return x
|
||||
156
comfy/ldm/mmaudio/vae/autoencoder.py
Normal file
156
comfy/ldm/mmaudio/vae/autoencoder.py
Normal file
@ -0,0 +1,156 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .distributions import DiagonalGaussianDistribution
|
||||
from .vae import VAE_16k
|
||||
from .bigvgan import BigVGANVocoder
|
||||
import logging
|
||||
|
||||
try:
|
||||
import torchaudio
|
||||
except:
|
||||
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
|
||||
return norm_fn(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes, norm_fn):
|
||||
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
|
||||
return output
|
||||
|
||||
class MelConverter(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sampling_rate: float,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
fmin: float,
|
||||
fmax: float,
|
||||
norm_fn,
|
||||
):
|
||||
super().__init__()
|
||||
self.sampling_rate = sampling_rate
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.hop_size = hop_size
|
||||
self.win_size = win_size
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.norm_fn = norm_fn
|
||||
|
||||
# mel = librosa_mel_fn(sr=self.sampling_rate,
|
||||
# n_fft=self.n_fft,
|
||||
# n_mels=self.num_mels,
|
||||
# fmin=self.fmin,
|
||||
# fmax=self.fmax)
|
||||
# mel_basis = torch.from_numpy(mel).float()
|
||||
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
|
||||
hann_window = torch.hann_window(self.win_size)
|
||||
|
||||
self.register_buffer('mel_basis', mel_basis)
|
||||
self.register_buffer('hann_window', hann_window)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.mel_basis.device
|
||||
|
||||
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
|
||||
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
|
||||
|
||||
waveform = torch.nn.functional.pad(
|
||||
waveform.unsqueeze(1),
|
||||
[int((self.n_fft - self.hop_size) / 2),
|
||||
int((self.n_fft - self.hop_size) / 2)],
|
||||
mode='reflect')
|
||||
waveform = waveform.squeeze(1)
|
||||
|
||||
spec = torch.stft(waveform,
|
||||
self.n_fft,
|
||||
hop_length=self.hop_size,
|
||||
win_length=self.win_size,
|
||||
window=self.hann_window,
|
||||
center=center,
|
||||
pad_mode='reflect',
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True)
|
||||
|
||||
spec = torch.view_as_real(spec)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
||||
spec = torch.matmul(self.mel_basis, spec)
|
||||
spec = spectral_normalize_torch(spec, self.norm_fn)
|
||||
|
||||
return spec
|
||||
|
||||
class AudioAutoencoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# ckpt_path: str,
|
||||
mode=Literal['16k', '44k'],
|
||||
need_vae_encoder: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mode == "16k", "Only 16k mode is supported currently."
|
||||
self.mel_converter = MelConverter(sampling_rate=16_000,
|
||||
n_fft=1024,
|
||||
num_mels=80,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=8_000,
|
||||
norm_fn=torch.log10)
|
||||
|
||||
self.vae = VAE_16k().eval()
|
||||
|
||||
bigvgan_config = {
|
||||
"resblock": "1",
|
||||
"num_mels": 80,
|
||||
"upsample_rates": [4, 4, 2, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [
|
||||
[1, 3, 5],
|
||||
[1, 3, 5],
|
||||
[1, 3, 5],
|
||||
],
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": True,
|
||||
}
|
||||
|
||||
self.vocoder = BigVGANVocoder(
|
||||
bigvgan_config
|
||||
).eval()
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_audio(self, x) -> DiagonalGaussianDistribution:
|
||||
# x: (B * L)
|
||||
mel = self.mel_converter(x)
|
||||
dist = self.vae.encode(mel)
|
||||
|
||||
return dist
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, z):
|
||||
mel_decoded = self.vae.decode(z)
|
||||
audio = self.vocoder(mel_decoded)
|
||||
|
||||
audio = torchaudio.functional.resample(audio, 16000, 44100)
|
||||
return audio
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audio):
|
||||
audio = audio.mean(dim=1)
|
||||
audio = torchaudio.functional.resample(audio, 44100, 16000)
|
||||
dist = self.encode_audio(audio)
|
||||
return dist.mean
|
||||
219
comfy/ldm/mmaudio/vae/bigvgan.py
Normal file
219
comfy/ldm/mmaudio/vae/bigvgan.py
Normal file
@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from types import SimpleNamespace
|
||||
from . import activations
|
||||
from .alias_free_torch import Activation1d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
||||
super(AMPBlock1, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]))
|
||||
])
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1))
|
||||
])
|
||||
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AMPBlock2(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
|
||||
super(AMPBlock2, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))
|
||||
])
|
||||
|
||||
self.num_layers = len(self.convs) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c, a in zip(self.convs, self.activations):
|
||||
xt = a(x)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BigVGANVocoder(torch.nn.Module):
|
||||
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
||||
def __init__(self, h):
|
||||
super().__init__()
|
||||
if isinstance(h, dict):
|
||||
h = SimpleNamespace(**h)
|
||||
self.h = h
|
||||
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
# pre conv
|
||||
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
||||
|
||||
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
|
||||
|
||||
# transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
nn.ModuleList([
|
||||
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2)
|
||||
]))
|
||||
|
||||
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# post conv
|
||||
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# pre conv
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
# upsampling
|
||||
for i_up in range(len(self.ups[i])):
|
||||
x = self.ups[i][i_up](x)
|
||||
# AMP blocks
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
# post conv
|
||||
x = self.activation_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
92
comfy/ldm/mmaudio/vae/distributions.py
Normal file
92
comfy/ldm/mmaudio/vae/distributions.py
Normal file
@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user