mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 02:57:24 +08:00
Merge branch 'master' into execution_model_inversion
This commit is contained in:
commit
9d624564fa
@ -7,7 +7,7 @@ on:
|
|||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "121"
|
default: "124"
|
||||||
|
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
@ -19,7 +19,7 @@ on:
|
|||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "2"
|
default: "3"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
@ -49,7 +49,7 @@ jobs:
|
|||||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||||
./python.exe get-pip.py
|
./python.exe get-pip.py
|
||||||
python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||||
ls ../temp_wheel_dir
|
ls ../temp_wheel_dir
|
||||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||||
|
|||||||
84
README.md
84
README.md
@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/) and [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
|
||||||
- Asynchronous Queue system
|
- Asynchronous Queue system
|
||||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||||
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
|
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
|
||||||
@ -41,29 +41,32 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
|
|
||||||
## Shortcuts
|
## Shortcuts
|
||||||
|
|
||||||
| Keybind | Explanation |
|
| Keybind | Explanation |
|
||||||
|---------------------------|--------------------------------------------------------------------------------------------------------------------|
|
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||||
| Ctrl + Enter | Queue up current graph for generation |
|
| Ctrl + Enter | Queue up current graph for generation |
|
||||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||||
| Ctrl + S | Save workflow |
|
| Ctrl + S | Save workflow |
|
||||||
| Ctrl + O | Load workflow |
|
| Ctrl + O | Load workflow |
|
||||||
| Ctrl + A | Select all nodes |
|
| Ctrl + A | Select all nodes |
|
||||||
| Alt + C | Collapse/uncollapse selected nodes |
|
| Alt + C | Collapse/uncollapse selected nodes |
|
||||||
| Ctrl + M | Mute/unmute selected nodes |
|
| Ctrl + M | Mute/unmute selected nodes |
|
||||||
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||||
| Delete/Backspace | Delete selected nodes |
|
| Delete/Backspace | Delete selected nodes |
|
||||||
| Ctrl + Delete/Backspace | Delete the current graph |
|
| Ctrl + Backspace | Delete the current graph |
|
||||||
| Space | Move the canvas around when held and moving the cursor |
|
| Space | Move the canvas around when held and moving the cursor |
|
||||||
| Ctrl/Shift + Click | Add clicked node to selection |
|
| Ctrl/Shift + Click | Add clicked node to selection |
|
||||||
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
|
||||||
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
|
||||||
| Shift + Drag | Move multiple selected nodes at the same time |
|
| Shift + Drag | Move multiple selected nodes at the same time |
|
||||||
| Ctrl + D | Load default graph |
|
| Ctrl + D | Load default graph |
|
||||||
| Q | Toggle visibility of the queue |
|
| Alt + `+` | Canvas Zoom in |
|
||||||
| H | Toggle visibility of history |
|
| Alt + `-` | Canvas Zoom out |
|
||||||
| R | Refresh graph |
|
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
|
||||||
| Double-Click LMB | Open node quick search palette |
|
| Q | Toggle visibility of the queue |
|
||||||
|
| H | Toggle visibility of history |
|
||||||
|
| R | Refresh graph |
|
||||||
|
| Double-Click LMB | Open node quick search palette |
|
||||||
|
|
||||||
Ctrl can also be replaced with Cmd instead for macOS users
|
Ctrl can also be replaced with Cmd instead for macOS users
|
||||||
|
|
||||||
@ -99,11 +102,11 @@ Put your VAE in: models/vae
|
|||||||
### AMD GPUs (Linux only)
|
### AMD GPUs (Linux only)
|
||||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7```
|
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```
|
||||||
|
|
||||||
### NVIDIA
|
### NVIDIA
|
||||||
|
|
||||||
@ -113,7 +116,7 @@ Nvidia users should install stable pytorch using this command:
|
|||||||
|
|
||||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
|
||||||
|
|
||||||
#### Troubleshooting
|
#### Troubleshooting
|
||||||
|
|
||||||
@ -133,7 +136,16 @@ After this you should have everything installed and can proceed to running Comfy
|
|||||||
|
|
||||||
### Others:
|
### Others:
|
||||||
|
|
||||||
#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
|
#### Intel GPUs
|
||||||
|
|
||||||
|
Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
|
||||||
|
|
||||||
|
1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
|
||||||
|
1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
|
||||||
|
1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
|
||||||
|
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
|
||||||
|
|
||||||
|
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||||
|
|
||||||
#### Apple Mac silicon
|
#### Apple Mac silicon
|
||||||
|
|
||||||
@ -195,20 +207,20 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
|||||||
```embedding:embedding_filename.pt```
|
```embedding:embedding_filename.pt```
|
||||||
|
|
||||||
|
|
||||||
## How to increase generation speed?
|
|
||||||
|
|
||||||
Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU.
|
|
||||||
|
|
||||||
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
|
|
||||||
|
|
||||||
```--dont-upcast-attention```
|
|
||||||
|
|
||||||
## How to show high-quality previews?
|
## How to show high-quality previews?
|
||||||
|
|
||||||
Use ```--preview-method auto``` to enable previews.
|
Use ```--preview-method auto``` to enable previews.
|
||||||
|
|
||||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
||||||
|
|
||||||
|
## How to use TLS/SSL?
|
||||||
|
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
|
||||||
|
|
||||||
|
Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`.
|
||||||
|
|
||||||
|
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
|
||||||
|
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
|
||||||
|
|
||||||
## Support and dev channel
|
## Support and dev channel
|
||||||
|
|
||||||
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
|
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
|
||||||
|
|||||||
@ -52,6 +52,7 @@ class ControlNet(nn.Module):
|
|||||||
adm_in_channels=None,
|
adm_in_channels=None,
|
||||||
transformer_depth_middle=None,
|
transformer_depth_middle=None,
|
||||||
transformer_depth_output=None,
|
transformer_depth_output=None,
|
||||||
|
attn_precision=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=comfy.ops.disable_weight_init,
|
operations=comfy.ops.disable_weight_init,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -202,7 +203,7 @@ class ControlNet(nn.Module):
|
|||||||
SpatialTransformer(
|
SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
@ -262,7 +263,7 @@ class ControlNet(nn.Module):
|
|||||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
|
||||||
),
|
),
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
|
|||||||
@ -35,6 +35,8 @@ parser = argparse.ArgumentParser()
|
|||||||
|
|
||||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||||
|
parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
|
||||||
|
parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
|
||||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||||
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
||||||
|
|
||||||
@ -49,7 +51,6 @@ cm_group = parser.add_mutually_exclusive_group()
|
|||||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||||
|
|
||||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
|
||||||
|
|
||||||
fp_group = parser.add_mutually_exclusive_group()
|
fp_group = parser.add_mutually_exclusive_group()
|
||||||
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||||
@ -74,6 +75,7 @@ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store
|
|||||||
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
|
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
|
||||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||||
|
|
||||||
|
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||||
|
|
||||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||||
|
|
||||||
@ -98,6 +100,11 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he
|
|||||||
|
|
||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||||
|
|
||||||
|
upcast = parser.add_mutually_exclusive_group()
|
||||||
|
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
|
||||||
|
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
||||||
|
|
||||||
|
|
||||||
vram_group = parser.add_mutually_exclusive_group()
|
vram_group = parser.add_mutually_exclusive_group()
|
||||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||||
|
|||||||
@ -29,7 +29,12 @@ class CONDRegular:
|
|||||||
|
|
||||||
class CONDNoiseShape(CONDRegular):
|
class CONDNoiseShape(CONDRegular):
|
||||||
def process_cond(self, batch_size, device, area, **kwargs):
|
def process_cond(self, batch_size, device, area, **kwargs):
|
||||||
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
data = self.cond
|
||||||
|
if area is not None:
|
||||||
|
dims = len(area) // 2
|
||||||
|
for i in range(dims):
|
||||||
|
data = data.narrow(i + 2, area[i + dims], area[i])
|
||||||
|
|
||||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -129,8 +129,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
if s_churn > 0:
|
||||||
sigma_hat = sigmas[i] * (gamma + 1)
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
|
sigma_hat = sigmas[i] * (gamma + 1)
|
||||||
|
else:
|
||||||
|
gamma = 0
|
||||||
|
sigma_hat = sigmas[i]
|
||||||
|
|
||||||
if gamma > 0:
|
if gamma > 0:
|
||||||
eps = torch.randn_like(x) * s_noise
|
eps = torch.randn_like(x) * s_noise
|
||||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||||
@ -170,7 +175,13 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
if s_churn > 0:
|
||||||
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
|
sigma_hat = sigmas[i] * (gamma + 1)
|
||||||
|
else:
|
||||||
|
gamma = 0
|
||||||
|
sigma_hat = sigmas[i]
|
||||||
|
|
||||||
sigma_hat = sigmas[i] * (gamma + 1)
|
sigma_hat = sigmas[i] * (gamma + 1)
|
||||||
if gamma > 0:
|
if gamma > 0:
|
||||||
eps = torch.randn_like(x) * s_noise
|
eps = torch.randn_like(x) * s_noise
|
||||||
@ -199,8 +210,13 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
if s_churn > 0:
|
||||||
sigma_hat = sigmas[i] * (gamma + 1)
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
|
sigma_hat = sigmas[i] * (gamma + 1)
|
||||||
|
else:
|
||||||
|
gamma = 0
|
||||||
|
sigma_hat = sigmas[i]
|
||||||
|
|
||||||
if gamma > 0:
|
if gamma > 0:
|
||||||
eps = torch.randn_like(x) * s_noise
|
eps = torch.randn_like(x) * s_noise
|
||||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||||
@ -527,6 +543,9 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||||
"""DPM-Solver++ (stochastic)."""
|
"""DPM-Solver++ (stochastic)."""
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
@ -595,6 +614,8 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
"""DPM-Solver++(2M) SDE."""
|
"""DPM-Solver++(2M) SDE."""
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
if solver_type not in {'heun', 'midpoint'}:
|
if solver_type not in {'heun', 'midpoint'}:
|
||||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||||
@ -642,6 +663,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
"""DPM-Solver++(3M) SDE."""
|
"""DPM-Solver++(3M) SDE."""
|
||||||
|
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
seed = extra_args.get("seed", None)
|
seed = extra_args.get("seed", None)
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||||
@ -690,18 +714,27 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import torch
|
|||||||
|
|
||||||
class LatentFormat:
|
class LatentFormat:
|
||||||
scale_factor = 1.0
|
scale_factor = 1.0
|
||||||
|
latent_channels = 4
|
||||||
latent_rgb_factors = None
|
latent_rgb_factors = None
|
||||||
taesd_decoder_name = None
|
taesd_decoder_name = None
|
||||||
|
|
||||||
@ -24,8 +25,9 @@ class SD15(LatentFormat):
|
|||||||
self.taesd_decoder_name = "taesd_decoder"
|
self.taesd_decoder_name = "taesd_decoder"
|
||||||
|
|
||||||
class SDXL(LatentFormat):
|
class SDXL(LatentFormat):
|
||||||
|
scale_factor = 0.13025
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 0.13025
|
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
# R G B
|
# R G B
|
||||||
[ 0.3920, 0.4054, 0.4549],
|
[ 0.3920, 0.4054, 0.4549],
|
||||||
@ -72,6 +74,7 @@ class SD_X4(LatentFormat):
|
|||||||
]
|
]
|
||||||
|
|
||||||
class SC_Prior(LatentFormat):
|
class SC_Prior(LatentFormat):
|
||||||
|
latent_channels = 16
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 1.0
|
self.scale_factor = 1.0
|
||||||
self.latent_rgb_factors = [
|
self.latent_rgb_factors = [
|
||||||
@ -102,3 +105,37 @@ class SC_B(LatentFormat):
|
|||||||
[-0.3087, -0.1535, 0.0366],
|
[-0.3087, -0.1535, 0.0366],
|
||||||
[ 0.0290, -0.1574, -0.4078]
|
[ 0.0290, -0.1574, -0.4078]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
class SD3(LatentFormat):
|
||||||
|
latent_channels = 16
|
||||||
|
def __init__(self):
|
||||||
|
self.scale_factor = 1.5305
|
||||||
|
self.shift_factor = 0.0609
|
||||||
|
self.latent_rgb_factors = [
|
||||||
|
[-0.0645, 0.0177, 0.1052],
|
||||||
|
[ 0.0028, 0.0312, 0.0650],
|
||||||
|
[ 0.1848, 0.0762, 0.0360],
|
||||||
|
[ 0.0944, 0.0360, 0.0889],
|
||||||
|
[ 0.0897, 0.0506, -0.0364],
|
||||||
|
[-0.0020, 0.1203, 0.0284],
|
||||||
|
[ 0.0855, 0.0118, 0.0283],
|
||||||
|
[-0.0539, 0.0658, 0.1047],
|
||||||
|
[-0.0057, 0.0116, 0.0700],
|
||||||
|
[-0.0412, 0.0281, -0.0039],
|
||||||
|
[ 0.1106, 0.1171, 0.1220],
|
||||||
|
[-0.0248, 0.0682, -0.0481],
|
||||||
|
[ 0.0815, 0.0846, 0.1207],
|
||||||
|
[-0.0120, -0.0055, -0.0867],
|
||||||
|
[-0.0749, -0.0634, -0.0456],
|
||||||
|
[-0.1418, -0.1457, -0.1259]
|
||||||
|
]
|
||||||
|
self.taesd_decoder_name = "taesd3_decoder"
|
||||||
|
|
||||||
|
def process_in(self, latent):
|
||||||
|
return (latent - self.shift_factor) * self.scale_factor
|
||||||
|
|
||||||
|
def process_out(self, latent):
|
||||||
|
return (latent / self.scale_factor) + self.shift_factor
|
||||||
|
|
||||||
|
class StableAudio1(LatentFormat):
|
||||||
|
latent_channels = 64
|
||||||
|
|||||||
282
comfy/ldm/audio/autoencoder.py
Normal file
282
comfy/ldm/audio/autoencoder.py
Normal file
@ -0,0 +1,282 @@
|
|||||||
|
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from typing import Literal, Dict, Any
|
||||||
|
import math
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
def vae_sample(mean, scale):
|
||||||
|
stdev = nn.functional.softplus(scale) + 1e-4
|
||||||
|
var = stdev * stdev
|
||||||
|
logvar = torch.log(var)
|
||||||
|
latents = torch.randn_like(mean) * stdev + mean
|
||||||
|
|
||||||
|
kl = (mean * mean + var - logvar - 1).sum(1).mean()
|
||||||
|
|
||||||
|
return latents, kl
|
||||||
|
|
||||||
|
class VAEBottleneck(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.is_discrete = False
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, **kwargs):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
mean, scale = x.chunk(2, dim=1)
|
||||||
|
|
||||||
|
x, kl = vae_sample(mean, scale)
|
||||||
|
|
||||||
|
info["kl"] = kl
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def snake_beta(x, alpha, beta):
|
||||||
|
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
|
||||||
|
|
||||||
|
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
||||||
|
class SnakeBeta(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
||||||
|
super(SnakeBeta, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||||
|
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
else: # linear scale alphas initialized to ones
|
||||||
|
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
||||||
|
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
|
# self.alpha.requires_grad = alpha_trainable
|
||||||
|
# self.beta.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
|
||||||
|
beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
beta = torch.exp(beta)
|
||||||
|
x = snake_beta(x, alpha, beta)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def WNConv1d(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||||
|
except:
|
||||||
|
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||||
|
|
||||||
|
def WNConvTranspose1d(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||||
|
except:
|
||||||
|
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||||
|
|
||||||
|
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||||
|
if activation == "elu":
|
||||||
|
act = torch.nn.ELU()
|
||||||
|
elif activation == "snake":
|
||||||
|
act = SnakeBeta(channels)
|
||||||
|
elif activation == "none":
|
||||||
|
act = torch.nn.Identity()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown activation {activation}")
|
||||||
|
|
||||||
|
if antialias:
|
||||||
|
act = Activation1d(act)
|
||||||
|
|
||||||
|
return act
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualUnit(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dilation = dilation
|
||||||
|
|
||||||
|
padding = (dilation * (7-1)) // 2
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||||
|
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||||
|
kernel_size=7, dilation=dilation, padding=padding),
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||||
|
WNConv1d(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
kernel_size=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = x
|
||||||
|
|
||||||
|
#x = checkpoint(self.layers, x)
|
||||||
|
x = self.layers(x)
|
||||||
|
|
||||||
|
return x + res
|
||||||
|
|
||||||
|
class EncoderBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
ResidualUnit(in_channels=in_channels,
|
||||||
|
out_channels=in_channels, dilation=1, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=in_channels,
|
||||||
|
out_channels=in_channels, dilation=3, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=in_channels,
|
||||||
|
out_channels=in_channels, dilation=9, use_snake=use_snake),
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||||
|
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||||
|
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
class DecoderBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if use_nearest_upsample:
|
||||||
|
upsample_layer = nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=stride, mode="nearest"),
|
||||||
|
WNConv1d(in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=2*stride,
|
||||||
|
stride=1,
|
||||||
|
bias=False,
|
||||||
|
padding='same')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||||
|
upsample_layer,
|
||||||
|
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
dilation=1, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
dilation=3, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
dilation=9, use_snake=use_snake),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
class OobleckEncoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels=2,
|
||||||
|
channels=128,
|
||||||
|
latent_dim=32,
|
||||||
|
c_mults = [1, 2, 4, 8],
|
||||||
|
strides = [2, 4, 8, 8],
|
||||||
|
use_snake=False,
|
||||||
|
antialias_activation=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
c_mults = [1] + c_mults
|
||||||
|
|
||||||
|
self.depth = len(c_mults)
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
|
||||||
|
]
|
||||||
|
|
||||||
|
for i in range(self.depth-1):
|
||||||
|
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
|
||||||
|
|
||||||
|
layers += [
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
|
||||||
|
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
class OobleckDecoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
out_channels=2,
|
||||||
|
channels=128,
|
||||||
|
latent_dim=32,
|
||||||
|
c_mults = [1, 2, 4, 8],
|
||||||
|
strides = [2, 4, 8, 8],
|
||||||
|
use_snake=False,
|
||||||
|
antialias_activation=False,
|
||||||
|
use_nearest_upsample=False,
|
||||||
|
final_tanh=True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
c_mults = [1] + c_mults
|
||||||
|
|
||||||
|
self.depth = len(c_mults)
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
|
||||||
|
]
|
||||||
|
|
||||||
|
for i in range(self.depth-1, 0, -1):
|
||||||
|
layers += [DecoderBlock(
|
||||||
|
in_channels=c_mults[i]*channels,
|
||||||
|
out_channels=c_mults[i-1]*channels,
|
||||||
|
stride=strides[i-1],
|
||||||
|
use_snake=use_snake,
|
||||||
|
antialias_activation=antialias_activation,
|
||||||
|
use_nearest_upsample=use_nearest_upsample
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
layers += [
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
|
||||||
|
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
|
||||||
|
nn.Tanh() if final_tanh else nn.Identity()
|
||||||
|
]
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioOobleckVAE(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels=2,
|
||||||
|
channels=128,
|
||||||
|
latent_dim=64,
|
||||||
|
c_mults = [1, 2, 4, 8, 16],
|
||||||
|
strides = [2, 4, 4, 8, 8],
|
||||||
|
use_snake=True,
|
||||||
|
antialias_activation=False,
|
||||||
|
use_nearest_upsample=False,
|
||||||
|
final_tanh=False):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
|
||||||
|
self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
|
||||||
|
use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
|
||||||
|
self.bottleneck = VAEBottleneck()
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return self.bottleneck.encode(self.encoder(x))
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return self.decoder(self.bottleneck.decode(x))
|
||||||
|
|
||||||
888
comfy/ldm/audio/dit.py
Normal file
888
comfy/ldm/audio/dit.py
Normal file
@ -0,0 +1,888 @@
|
|||||||
|
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
import typing as tp
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
class FourierFeatures(nn.Module):
|
||||||
|
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
|
||||||
|
super().__init__()
|
||||||
|
assert out_features % 2 == 0
|
||||||
|
self.weight = nn.Parameter(torch.empty(
|
||||||
|
[out_features // 2, in_features], dtype=dtype, device=device))
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
|
||||||
|
return torch.cat([f.cos(), f.sin()], dim=-1)
|
||||||
|
|
||||||
|
# norms
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
|
||||||
|
"""
|
||||||
|
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
|
||||||
|
else:
|
||||||
|
self.beta = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
beta = self.beta
|
||||||
|
if self.beta is not None:
|
||||||
|
beta = beta.to(dtype=x.dtype, device=x.device)
|
||||||
|
return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
|
||||||
|
|
||||||
|
class GLU(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim_in,
|
||||||
|
dim_out,
|
||||||
|
activation,
|
||||||
|
use_conv = False,
|
||||||
|
conv_kernel_size = 3,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.act = activation
|
||||||
|
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
|
||||||
|
self.use_conv = use_conv
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.use_conv:
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.proj(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
else:
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
|
x, gate = x.chunk(2, dim = -1)
|
||||||
|
return x * self.act(gate)
|
||||||
|
|
||||||
|
class AbsolutePositionalEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, max_seq_len):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim ** -0.5
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.emb = nn.Embedding(max_seq_len, dim)
|
||||||
|
|
||||||
|
def forward(self, x, pos = None, seq_start_pos = None):
|
||||||
|
seq_len, device = x.shape[1], x.device
|
||||||
|
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
|
||||||
|
|
||||||
|
if pos is None:
|
||||||
|
pos = torch.arange(seq_len, device = device)
|
||||||
|
|
||||||
|
if seq_start_pos is not None:
|
||||||
|
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
|
||||||
|
|
||||||
|
pos_emb = self.emb(pos)
|
||||||
|
pos_emb = pos_emb * self.scale
|
||||||
|
return pos_emb
|
||||||
|
|
||||||
|
class ScaledSinusoidalEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, theta = 10000):
|
||||||
|
super().__init__()
|
||||||
|
assert (dim % 2) == 0, 'dimension must be divisible by 2'
|
||||||
|
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
|
||||||
|
|
||||||
|
half_dim = dim // 2
|
||||||
|
freq_seq = torch.arange(half_dim).float() / half_dim
|
||||||
|
inv_freq = theta ** -freq_seq
|
||||||
|
self.register_buffer('inv_freq', inv_freq, persistent = False)
|
||||||
|
|
||||||
|
def forward(self, x, pos = None, seq_start_pos = None):
|
||||||
|
seq_len, device = x.shape[1], x.device
|
||||||
|
|
||||||
|
if pos is None:
|
||||||
|
pos = torch.arange(seq_len, device = device)
|
||||||
|
|
||||||
|
if seq_start_pos is not None:
|
||||||
|
pos = pos - seq_start_pos[..., None]
|
||||||
|
|
||||||
|
emb = torch.einsum('i, j -> i j', pos, self.inv_freq)
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||||
|
return emb * self.scale
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
use_xpos = False,
|
||||||
|
scale_base = 512,
|
||||||
|
interpolation_factor = 1.,
|
||||||
|
base = 10000,
|
||||||
|
base_rescale_factor = 1.
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||||
|
# has some connection to NTK literature
|
||||||
|
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||||
|
base *= base_rescale_factor ** (dim / (dim - 2))
|
||||||
|
|
||||||
|
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||||
|
self.register_buffer('inv_freq', inv_freq)
|
||||||
|
|
||||||
|
assert interpolation_factor >= 1.
|
||||||
|
self.interpolation_factor = interpolation_factor
|
||||||
|
|
||||||
|
if not use_xpos:
|
||||||
|
self.register_buffer('scale', None)
|
||||||
|
return
|
||||||
|
|
||||||
|
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||||
|
|
||||||
|
self.scale_base = scale_base
|
||||||
|
self.register_buffer('scale', scale)
|
||||||
|
|
||||||
|
def forward_from_seq_len(self, seq_len, device, dtype):
|
||||||
|
# device = self.inv_freq.device
|
||||||
|
|
||||||
|
t = torch.arange(seq_len, device=device, dtype=dtype)
|
||||||
|
return self.forward(t)
|
||||||
|
|
||||||
|
def forward(self, t):
|
||||||
|
# device = self.inv_freq.device
|
||||||
|
device = t.device
|
||||||
|
dtype = t.dtype
|
||||||
|
|
||||||
|
# t = t.to(torch.float32)
|
||||||
|
|
||||||
|
t = t / self.interpolation_factor
|
||||||
|
|
||||||
|
freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
|
||||||
|
freqs = torch.cat((freqs, freqs), dim = -1)
|
||||||
|
|
||||||
|
if self.scale is None:
|
||||||
|
return freqs, 1.
|
||||||
|
|
||||||
|
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
||||||
|
scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
|
||||||
|
scale = torch.cat((scale, scale), dim = -1)
|
||||||
|
|
||||||
|
return freqs, scale
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x = rearrange(x, '... (j d) -> ... j d', j = 2)
|
||||||
|
x1, x2 = x.unbind(dim = -2)
|
||||||
|
return torch.cat((-x2, x1), dim = -1)
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(t, freqs, scale = 1):
|
||||||
|
out_dtype = t.dtype
|
||||||
|
|
||||||
|
# cast to float32 if necessary for numerical stability
|
||||||
|
dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
|
||||||
|
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
|
||||||
|
freqs, t = freqs.to(dtype), t.to(dtype)
|
||||||
|
freqs = freqs[-seq_len:, :]
|
||||||
|
|
||||||
|
if t.ndim == 4 and freqs.ndim == 3:
|
||||||
|
freqs = rearrange(freqs, 'b n d -> b 1 n d')
|
||||||
|
|
||||||
|
# partial rotary embeddings, Wang et al. GPT-J
|
||||||
|
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
|
||||||
|
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
||||||
|
|
||||||
|
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
|
||||||
|
|
||||||
|
return torch.cat((t, t_unrotated), dim = -1)
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_out = None,
|
||||||
|
mult = 4,
|
||||||
|
no_bias = False,
|
||||||
|
glu = True,
|
||||||
|
use_conv = False,
|
||||||
|
conv_kernel_size = 3,
|
||||||
|
zero_init_output = True,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
|
||||||
|
# Default to SwiGLU
|
||||||
|
|
||||||
|
activation = nn.SiLU()
|
||||||
|
|
||||||
|
dim_out = dim if dim_out is None else dim_out
|
||||||
|
|
||||||
|
if glu:
|
||||||
|
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
linear_in = nn.Sequential(
|
||||||
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
|
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
|
||||||
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
|
activation
|
||||||
|
)
|
||||||
|
|
||||||
|
linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# # init last linear layer to 0
|
||||||
|
# if zero_init_output:
|
||||||
|
# nn.init.zeros_(linear_out.weight)
|
||||||
|
# if not no_bias:
|
||||||
|
# nn.init.zeros_(linear_out.bias)
|
||||||
|
|
||||||
|
|
||||||
|
self.ff = nn.Sequential(
|
||||||
|
linear_in,
|
||||||
|
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||||
|
linear_out,
|
||||||
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.ff(x)
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_heads = 64,
|
||||||
|
dim_context = None,
|
||||||
|
causal = False,
|
||||||
|
zero_init_output=True,
|
||||||
|
qk_norm = False,
|
||||||
|
natten_kernel_size = None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.dim_heads = dim_heads
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
|
dim_kv = dim_context if dim_context is not None else dim
|
||||||
|
|
||||||
|
self.num_heads = dim // dim_heads
|
||||||
|
self.kv_heads = dim_kv // dim_heads
|
||||||
|
|
||||||
|
if dim_context is not None:
|
||||||
|
self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# if zero_init_output:
|
||||||
|
# nn.init.zeros_(self.to_out.weight)
|
||||||
|
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
context = None,
|
||||||
|
mask = None,
|
||||||
|
context_mask = None,
|
||||||
|
rotary_pos_emb = None,
|
||||||
|
causal = None
|
||||||
|
):
|
||||||
|
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||||
|
|
||||||
|
kv_input = context if has_context else x
|
||||||
|
|
||||||
|
if hasattr(self, 'to_q'):
|
||||||
|
# Use separate linear projections for q and k/v
|
||||||
|
q = self.to_q(x)
|
||||||
|
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
|
||||||
|
|
||||||
|
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
|
||||||
|
else:
|
||||||
|
# Use fused linear projection
|
||||||
|
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||||
|
|
||||||
|
# Normalize q and k for cosine sim attention
|
||||||
|
if self.qk_norm:
|
||||||
|
q = F.normalize(q, dim=-1)
|
||||||
|
k = F.normalize(k, dim=-1)
|
||||||
|
|
||||||
|
if rotary_pos_emb is not None and not has_context:
|
||||||
|
freqs, _ = rotary_pos_emb
|
||||||
|
|
||||||
|
q_dtype = q.dtype
|
||||||
|
k_dtype = k.dtype
|
||||||
|
|
||||||
|
q = q.to(torch.float32)
|
||||||
|
k = k.to(torch.float32)
|
||||||
|
freqs = freqs.to(torch.float32)
|
||||||
|
|
||||||
|
q = apply_rotary_pos_emb(q, freqs)
|
||||||
|
k = apply_rotary_pos_emb(k, freqs)
|
||||||
|
|
||||||
|
q = q.to(q_dtype)
|
||||||
|
k = k.to(k_dtype)
|
||||||
|
|
||||||
|
input_mask = context_mask
|
||||||
|
|
||||||
|
if input_mask is None and not has_context:
|
||||||
|
input_mask = mask
|
||||||
|
|
||||||
|
# determine masking
|
||||||
|
masks = []
|
||||||
|
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
|
||||||
|
|
||||||
|
if input_mask is not None:
|
||||||
|
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
|
||||||
|
masks.append(~input_mask)
|
||||||
|
|
||||||
|
# Other masks will be added here later
|
||||||
|
|
||||||
|
if len(masks) > 0:
|
||||||
|
final_attn_mask = ~or_reduce(masks)
|
||||||
|
|
||||||
|
n, device = q.shape[-2], q.device
|
||||||
|
|
||||||
|
causal = self.causal if causal is None else causal
|
||||||
|
|
||||||
|
if n == 1 and causal:
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
if h != kv_h:
|
||||||
|
# Repeat interleave kv_heads to match q_heads
|
||||||
|
heads_per_kv_head = h // kv_h
|
||||||
|
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||||
|
|
||||||
|
out = optimized_attention(q, k, v, h, skip_reshape=True)
|
||||||
|
out = self.to_out(out)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = rearrange(mask, 'b n -> b n 1')
|
||||||
|
out = out.masked_fill(~mask, 0.)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
class ConformerModule(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
norm_kwargs = {},
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
self.in_norm = LayerNorm(dim, **norm_kwargs)
|
||||||
|
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||||
|
self.glu = GLU(dim, dim, nn.SiLU())
|
||||||
|
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
|
||||||
|
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
|
||||||
|
self.swish = nn.SiLU()
|
||||||
|
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.in_norm(x)
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.pointwise_conv(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
x = self.glu(x)
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.depthwise_conv(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
x = self.mid_norm(x)
|
||||||
|
x = self.swish(x)
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.pointwise_conv_2(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_heads = 64,
|
||||||
|
cross_attend = False,
|
||||||
|
dim_context = None,
|
||||||
|
global_cond_dim = None,
|
||||||
|
causal = False,
|
||||||
|
zero_init_branch_outputs = True,
|
||||||
|
conformer = False,
|
||||||
|
layer_ix = -1,
|
||||||
|
remove_norms = False,
|
||||||
|
attn_kwargs = {},
|
||||||
|
ff_kwargs = {},
|
||||||
|
norm_kwargs = {},
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.dim_heads = dim_heads
|
||||||
|
self.cross_attend = cross_attend
|
||||||
|
self.dim_context = dim_context
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
|
self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
|
||||||
|
|
||||||
|
self.self_attn = Attention(
|
||||||
|
dim,
|
||||||
|
dim_heads = dim_heads,
|
||||||
|
causal = causal,
|
||||||
|
zero_init_output=zero_init_branch_outputs,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
**attn_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if cross_attend:
|
||||||
|
self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
|
||||||
|
self.cross_attn = Attention(
|
||||||
|
dim,
|
||||||
|
dim_heads = dim_heads,
|
||||||
|
dim_context=dim_context,
|
||||||
|
causal = causal,
|
||||||
|
zero_init_output=zero_init_branch_outputs,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
**attn_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
|
||||||
|
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)
|
||||||
|
|
||||||
|
self.layer_ix = layer_ix
|
||||||
|
|
||||||
|
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
|
||||||
|
|
||||||
|
self.global_cond_dim = global_cond_dim
|
||||||
|
|
||||||
|
if global_cond_dim is not None:
|
||||||
|
self.to_scale_shift_gate = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(global_cond_dim, dim * 6, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
nn.init.zeros_(self.to_scale_shift_gate[1].weight)
|
||||||
|
#nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
context = None,
|
||||||
|
global_cond=None,
|
||||||
|
mask = None,
|
||||||
|
context_mask = None,
|
||||||
|
rotary_pos_emb = None
|
||||||
|
):
|
||||||
|
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
||||||
|
|
||||||
|
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
|
||||||
|
|
||||||
|
# self-attention with adaLN
|
||||||
|
residual = x
|
||||||
|
x = self.pre_norm(x)
|
||||||
|
x = x * (1 + scale_self) + shift_self
|
||||||
|
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||||
|
x = x * torch.sigmoid(1 - gate_self)
|
||||||
|
x = x + residual
|
||||||
|
|
||||||
|
if context is not None:
|
||||||
|
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||||
|
|
||||||
|
if self.conformer is not None:
|
||||||
|
x = x + self.conformer(x)
|
||||||
|
|
||||||
|
# feedforward with adaLN
|
||||||
|
residual = x
|
||||||
|
x = self.ff_norm(x)
|
||||||
|
x = x * (1 + scale_ff) + shift_ff
|
||||||
|
x = self.ff(x)
|
||||||
|
x = x * torch.sigmoid(1 - gate_ff)
|
||||||
|
x = x + residual
|
||||||
|
|
||||||
|
else:
|
||||||
|
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||||
|
|
||||||
|
if context is not None:
|
||||||
|
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||||
|
|
||||||
|
if self.conformer is not None:
|
||||||
|
x = x + self.conformer(x)
|
||||||
|
|
||||||
|
x = x + self.ff(self.ff_norm(x))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class ContinuousTransformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
*,
|
||||||
|
dim_in = None,
|
||||||
|
dim_out = None,
|
||||||
|
dim_heads = 64,
|
||||||
|
cross_attend=False,
|
||||||
|
cond_token_dim=None,
|
||||||
|
global_cond_dim=None,
|
||||||
|
causal=False,
|
||||||
|
rotary_pos_emb=True,
|
||||||
|
zero_init_branch_outputs=True,
|
||||||
|
conformer=False,
|
||||||
|
use_sinusoidal_emb=False,
|
||||||
|
use_abs_pos_emb=False,
|
||||||
|
abs_pos_emb_max_length=10000,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.depth = depth
|
||||||
|
self.causal = causal
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
|
||||||
|
self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
|
||||||
|
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
|
||||||
|
|
||||||
|
if rotary_pos_emb:
|
||||||
|
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
||||||
|
else:
|
||||||
|
self.rotary_pos_emb = None
|
||||||
|
|
||||||
|
self.use_sinusoidal_emb = use_sinusoidal_emb
|
||||||
|
if use_sinusoidal_emb:
|
||||||
|
self.pos_emb = ScaledSinusoidalEmbedding(dim)
|
||||||
|
|
||||||
|
self.use_abs_pos_emb = use_abs_pos_emb
|
||||||
|
if use_abs_pos_emb:
|
||||||
|
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
|
||||||
|
|
||||||
|
for i in range(depth):
|
||||||
|
self.layers.append(
|
||||||
|
TransformerBlock(
|
||||||
|
dim,
|
||||||
|
dim_heads = dim_heads,
|
||||||
|
cross_attend = cross_attend,
|
||||||
|
dim_context = cond_token_dim,
|
||||||
|
global_cond_dim = global_cond_dim,
|
||||||
|
causal = causal,
|
||||||
|
zero_init_branch_outputs = zero_init_branch_outputs,
|
||||||
|
conformer=conformer,
|
||||||
|
layer_ix=i,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
mask = None,
|
||||||
|
prepend_embeds = None,
|
||||||
|
prepend_mask = None,
|
||||||
|
global_cond = None,
|
||||||
|
return_info = False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
batch, seq, device = *x.shape[:2], x.device
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"hidden_states": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
x = self.project_in(x)
|
||||||
|
|
||||||
|
if prepend_embeds is not None:
|
||||||
|
prepend_length, prepend_dim = prepend_embeds.shape[1:]
|
||||||
|
|
||||||
|
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
|
||||||
|
|
||||||
|
x = torch.cat((prepend_embeds, x), dim = -2)
|
||||||
|
|
||||||
|
if prepend_mask is not None or mask is not None:
|
||||||
|
mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
|
||||||
|
prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
|
||||||
|
|
||||||
|
mask = torch.cat((prepend_mask, mask), dim = -1)
|
||||||
|
|
||||||
|
# Attention layers
|
||||||
|
|
||||||
|
if self.rotary_pos_emb is not None:
|
||||||
|
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
|
||||||
|
else:
|
||||||
|
rotary_pos_emb = None
|
||||||
|
|
||||||
|
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||||
|
x = x + self.pos_emb(x)
|
||||||
|
|
||||||
|
# Iterate over the transformer layers
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||||
|
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
info["hidden_states"].append(x)
|
||||||
|
|
||||||
|
x = self.project_out(x)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class AudioDiffusionTransformer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
io_channels=64,
|
||||||
|
patch_size=1,
|
||||||
|
embed_dim=1536,
|
||||||
|
cond_token_dim=768,
|
||||||
|
project_cond_tokens=False,
|
||||||
|
global_cond_dim=1536,
|
||||||
|
project_global_cond=True,
|
||||||
|
input_concat_dim=0,
|
||||||
|
prepend_cond_dim=0,
|
||||||
|
depth=24,
|
||||||
|
num_heads=24,
|
||||||
|
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
|
||||||
|
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
|
||||||
|
audio_model="",
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dtype = dtype
|
||||||
|
self.cond_token_dim = cond_token_dim
|
||||||
|
|
||||||
|
# Timestep embeddings
|
||||||
|
timestep_features_dim = 256
|
||||||
|
|
||||||
|
self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.to_timestep_embed = nn.Sequential(
|
||||||
|
operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
if cond_token_dim > 0:
|
||||||
|
# Conditioning tokens
|
||||||
|
|
||||||
|
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
|
||||||
|
self.to_cond_embed = nn.Sequential(
|
||||||
|
operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cond_embed_dim = 0
|
||||||
|
|
||||||
|
if global_cond_dim > 0:
|
||||||
|
# Global conditioning
|
||||||
|
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
|
||||||
|
self.to_global_embed = nn.Sequential(
|
||||||
|
operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
if prepend_cond_dim > 0:
|
||||||
|
# Prepend conditioning
|
||||||
|
self.to_prepend_embed = nn.Sequential(
|
||||||
|
operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_concat_dim = input_concat_dim
|
||||||
|
|
||||||
|
dim_in = io_channels + self.input_concat_dim
|
||||||
|
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
# Transformer
|
||||||
|
|
||||||
|
self.transformer_type = transformer_type
|
||||||
|
|
||||||
|
self.global_cond_type = global_cond_type
|
||||||
|
|
||||||
|
if self.transformer_type == "continuous_transformer":
|
||||||
|
|
||||||
|
global_dim = None
|
||||||
|
|
||||||
|
if self.global_cond_type == "adaLN":
|
||||||
|
# The global conditioning is projected to the embed_dim already at this point
|
||||||
|
global_dim = embed_dim
|
||||||
|
|
||||||
|
self.transformer = ContinuousTransformer(
|
||||||
|
dim=embed_dim,
|
||||||
|
depth=depth,
|
||||||
|
dim_heads=embed_dim // num_heads,
|
||||||
|
dim_in=dim_in * patch_size,
|
||||||
|
dim_out=io_channels * patch_size,
|
||||||
|
cross_attend = cond_token_dim > 0,
|
||||||
|
cond_token_dim = cond_embed_dim,
|
||||||
|
global_cond_dim=global_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
|
||||||
|
|
||||||
|
self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
|
||||||
|
self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
mask=None,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_cond_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_embed=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
return_info=False,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
|
||||||
|
|
||||||
|
if global_embed is not None:
|
||||||
|
# Project the global conditioning to the embedding dimension
|
||||||
|
global_embed = self.to_global_embed(global_embed)
|
||||||
|
|
||||||
|
prepend_inputs = None
|
||||||
|
prepend_mask = None
|
||||||
|
prepend_length = 0
|
||||||
|
if prepend_cond is not None:
|
||||||
|
# Project the prepend conditioning to the embedding dimension
|
||||||
|
prepend_cond = self.to_prepend_embed(prepend_cond)
|
||||||
|
|
||||||
|
prepend_inputs = prepend_cond
|
||||||
|
if prepend_cond_mask is not None:
|
||||||
|
prepend_mask = prepend_cond_mask
|
||||||
|
|
||||||
|
if input_concat_cond is not None:
|
||||||
|
|
||||||
|
# Interpolate input_concat_cond to the same length as x
|
||||||
|
if input_concat_cond.shape[2] != x.shape[2]:
|
||||||
|
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
||||||
|
|
||||||
|
x = torch.cat([x, input_concat_cond], dim=1)
|
||||||
|
|
||||||
|
# Get the batch of timestep embeddings
|
||||||
|
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)
|
||||||
|
|
||||||
|
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
||||||
|
if global_embed is not None:
|
||||||
|
global_embed = global_embed + timestep_embed
|
||||||
|
else:
|
||||||
|
global_embed = timestep_embed
|
||||||
|
|
||||||
|
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
|
||||||
|
if self.global_cond_type == "prepend":
|
||||||
|
if prepend_inputs is None:
|
||||||
|
# Prepend inputs are just the global embed, and the mask is all ones
|
||||||
|
prepend_inputs = global_embed.unsqueeze(1)
|
||||||
|
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
# Prepend inputs are the prepend conditioning + the global embed
|
||||||
|
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
|
||||||
|
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
|
||||||
|
|
||||||
|
prepend_length = prepend_inputs.shape[1]
|
||||||
|
|
||||||
|
x = self.preprocess_conv(x) + x
|
||||||
|
|
||||||
|
x = rearrange(x, "b c t -> b t c")
|
||||||
|
|
||||||
|
extra_args = {}
|
||||||
|
|
||||||
|
if self.global_cond_type == "adaLN":
|
||||||
|
extra_args["global_cond"] = global_embed
|
||||||
|
|
||||||
|
if self.patch_size > 1:
|
||||||
|
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
|
||||||
|
|
||||||
|
if self.transformer_type == "x-transformers":
|
||||||
|
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
|
||||||
|
elif self.transformer_type == "continuous_transformer":
|
||||||
|
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
output, info = output
|
||||||
|
elif self.transformer_type == "mm_transformer":
|
||||||
|
output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
|
||||||
|
|
||||||
|
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
|
||||||
|
|
||||||
|
if self.patch_size > 1:
|
||||||
|
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
|
||||||
|
|
||||||
|
output = self.postprocess_conv(output) + output
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return output, info
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context=None,
|
||||||
|
context_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_embed=None,
|
||||||
|
negative_global_embed=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
mask=None,
|
||||||
|
return_info=False,
|
||||||
|
control=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs):
|
||||||
|
return self._forward(
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
cross_attn_cond=context,
|
||||||
|
cross_attn_cond_mask=context_mask,
|
||||||
|
input_concat_cond=input_concat_cond,
|
||||||
|
global_embed=global_embed,
|
||||||
|
prepend_cond=prepend_cond,
|
||||||
|
prepend_cond_mask=prepend_cond_mask,
|
||||||
|
mask=mask,
|
||||||
|
return_info=return_info,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
108
comfy/ldm/audio/embedders.py
Normal file
108
comfy/ldm/audio/embedders.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor, einsum
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
||||||
|
from einops import rearrange
|
||||||
|
import math
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
class LearnedPositionalEmbedding(nn.Module):
|
||||||
|
"""Used for continuous time"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
assert (dim % 2) == 0
|
||||||
|
half_dim = dim // 2
|
||||||
|
self.weights = nn.Parameter(torch.empty(half_dim))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
x = rearrange(x, "b -> b 1")
|
||||||
|
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
|
||||||
|
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
||||||
|
fouriered = torch.cat((x, fouriered), dim=-1)
|
||||||
|
return fouriered
|
||||||
|
|
||||||
|
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
|
||||||
|
return nn.Sequential(
|
||||||
|
LearnedPositionalEmbedding(dim),
|
||||||
|
comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NumberEmbedder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
features: int,
|
||||||
|
dim: int = 256,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.features = features
|
||||||
|
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
|
||||||
|
|
||||||
|
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
|
||||||
|
if not torch.is_tensor(x):
|
||||||
|
device = next(self.embedding.parameters()).device
|
||||||
|
x = torch.tensor(x, device=device)
|
||||||
|
assert isinstance(x, Tensor)
|
||||||
|
shape = x.shape
|
||||||
|
x = rearrange(x, "... -> (...)")
|
||||||
|
embedding = self.embedding(x)
|
||||||
|
x = embedding.view(*shape, self.features)
|
||||||
|
return x # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class Conditioner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
project_out: bool = False
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class NumberConditioner(Conditioner):
|
||||||
|
'''
|
||||||
|
Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
|
||||||
|
'''
|
||||||
|
def __init__(self,
|
||||||
|
output_dim: int,
|
||||||
|
min_val: float=0,
|
||||||
|
max_val: float=1
|
||||||
|
):
|
||||||
|
super().__init__(output_dim, output_dim)
|
||||||
|
|
||||||
|
self.min_val = min_val
|
||||||
|
self.max_val = max_val
|
||||||
|
|
||||||
|
self.embedder = NumberEmbedder(features=output_dim)
|
||||||
|
|
||||||
|
def forward(self, floats, device=None):
|
||||||
|
# Cast the inputs to floats
|
||||||
|
floats = [float(x) for x in floats]
|
||||||
|
|
||||||
|
if device is None:
|
||||||
|
device = next(self.embedder.parameters()).device
|
||||||
|
|
||||||
|
floats = torch.tensor(floats).to(device)
|
||||||
|
|
||||||
|
floats = floats.clamp(self.min_val, self.max_val)
|
||||||
|
|
||||||
|
normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
|
||||||
|
|
||||||
|
# Cast floats to same type as embedder
|
||||||
|
embedder_dtype = next(self.embedder.parameters()).dtype
|
||||||
|
normalized_floats = normalized_floats.to(embedder_dtype)
|
||||||
|
|
||||||
|
float_embeds = self.embedder(normalized_floats).unsqueeze(1)
|
||||||
|
|
||||||
|
return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
|
||||||
@ -17,7 +17,6 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
||||||
|
|||||||
@ -18,7 +18,6 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import numpy as np
|
|
||||||
import math
|
import math
|
||||||
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
||||||
# from .controlnet import ControlNetDeliverer
|
# from .controlnet import ControlNetDeliverer
|
||||||
|
|||||||
@ -1,6 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
# import pytorch_lightning as pl
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|||||||
@ -3,10 +3,10 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from typing import Optional, Any
|
from typing import Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
|
from .diffusionmodules.util import AlphaBlender, timestep_embedding
|
||||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||||
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
@ -19,13 +19,14 @@ from comfy.cli_args import args
|
|||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
# CrossAttn precision handling
|
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
|
||||||
if args.dont_upcast_attention:
|
|
||||||
logging.info("disabling upcasting of attention")
|
|
||||||
_ATTN_PRECISION = "fp16"
|
|
||||||
else:
|
|
||||||
_ATTN_PRECISION = "fp32"
|
|
||||||
|
|
||||||
|
def get_attn_precision(attn_precision):
|
||||||
|
if args.dont_upcast_attention:
|
||||||
|
return None
|
||||||
|
if FORCE_UPCAST_ATTENTION_DTYPE is not None:
|
||||||
|
return FORCE_UPCAST_ATTENTION_DTYPE
|
||||||
|
return attn_precision
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
@ -85,23 +86,35 @@ class FeedForward(nn.Module):
|
|||||||
def Normalize(in_channels, dtype=None, device=None):
|
def Normalize(in_channels, dtype=None, device=None):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def attention_basic(q, k, v, heads, mask=None):
|
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||||
b, _, dim_head = q.shape
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
dim_head //= heads
|
|
||||||
|
if skip_reshape:
|
||||||
|
b, _, _, dim_head = q.shape
|
||||||
|
else:
|
||||||
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
|
||||||
scale = dim_head ** -0.5
|
scale = dim_head ** -0.5
|
||||||
|
|
||||||
h = heads
|
h = heads
|
||||||
q, k, v = map(
|
if skip_reshape:
|
||||||
lambda t: t.unsqueeze(3)
|
q, k, v = map(
|
||||||
.reshape(b, -1, heads, dim_head)
|
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||||
.permute(0, 2, 1, 3)
|
(q, k, v),
|
||||||
.reshape(b * heads, -1, dim_head)
|
)
|
||||||
.contiguous(),
|
else:
|
||||||
(q, k, v),
|
q, k, v = map(
|
||||||
)
|
lambda t: t.unsqueeze(3)
|
||||||
|
.reshape(b, -1, heads, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b * heads, -1, dim_head)
|
||||||
|
.contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
# force cast to fp32 to avoid overflowing
|
# force cast to fp32 to avoid overflowing
|
||||||
if _ATTN_PRECISION =="fp32":
|
if attn_precision == torch.float32:
|
||||||
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||||
else:
|
else:
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||||
@ -135,18 +148,29 @@ def attention_basic(q, k, v, heads, mask=None):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def attention_sub_quad(query, key, value, heads, mask=None):
|
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||||
b, _, dim_head = query.shape
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
dim_head //= heads
|
|
||||||
|
if skip_reshape:
|
||||||
|
b, _, _, dim_head = query.shape
|
||||||
|
else:
|
||||||
|
b, _, dim_head = query.shape
|
||||||
|
dim_head //= heads
|
||||||
|
|
||||||
scale = dim_head ** -0.5
|
scale = dim_head ** -0.5
|
||||||
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
|
||||||
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
|
||||||
|
|
||||||
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
|
if skip_reshape:
|
||||||
|
query = query.reshape(b * heads, -1, dim_head)
|
||||||
|
value = value.reshape(b * heads, -1, dim_head)
|
||||||
|
key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
|
||||||
|
else:
|
||||||
|
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||||
|
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
|
||||||
|
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
|
||||||
|
|
||||||
|
|
||||||
dtype = query.dtype
|
dtype = query.dtype
|
||||||
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
|
upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
|
||||||
if upcast_attention:
|
if upcast_attention:
|
||||||
bytes_per_token = torch.finfo(torch.float32).bits//8
|
bytes_per_token = torch.finfo(torch.float32).bits//8
|
||||||
else:
|
else:
|
||||||
@ -195,29 +219,43 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
|||||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def attention_split(q, k, v, heads, mask=None):
|
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||||
b, _, dim_head = q.shape
|
attn_precision = get_attn_precision(attn_precision)
|
||||||
dim_head //= heads
|
|
||||||
|
if skip_reshape:
|
||||||
|
b, _, _, dim_head = q.shape
|
||||||
|
else:
|
||||||
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
|
||||||
scale = dim_head ** -0.5
|
scale = dim_head ** -0.5
|
||||||
|
|
||||||
h = heads
|
h = heads
|
||||||
q, k, v = map(
|
if skip_reshape:
|
||||||
lambda t: t.unsqueeze(3)
|
q, k, v = map(
|
||||||
.reshape(b, -1, heads, dim_head)
|
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||||
.permute(0, 2, 1, 3)
|
(q, k, v),
|
||||||
.reshape(b * heads, -1, dim_head)
|
)
|
||||||
.contiguous(),
|
else:
|
||||||
(q, k, v),
|
q, k, v = map(
|
||||||
)
|
lambda t: t.unsqueeze(3)
|
||||||
|
.reshape(b, -1, heads, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b * heads, -1, dim_head)
|
||||||
|
.contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
mem_free_total = model_management.get_free_memory(q.device)
|
mem_free_total = model_management.get_free_memory(q.device)
|
||||||
|
|
||||||
if _ATTN_PRECISION =="fp32":
|
if attn_precision == torch.float32:
|
||||||
element_size = 4
|
element_size = 4
|
||||||
|
upcast = True
|
||||||
else:
|
else:
|
||||||
element_size = q.element_size()
|
element_size = q.element_size()
|
||||||
|
upcast = False
|
||||||
|
|
||||||
gb = 1024 ** 3
|
gb = 1024 ** 3
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
|
||||||
@ -251,7 +289,7 @@ def attention_split(q, k, v, heads, mask=None):
|
|||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
if _ATTN_PRECISION =="fp32":
|
if upcast:
|
||||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
||||||
else:
|
else:
|
||||||
@ -297,26 +335,41 @@ def attention_split(q, k, v, heads, mask=None):
|
|||||||
BROKEN_XFORMERS = False
|
BROKEN_XFORMERS = False
|
||||||
try:
|
try:
|
||||||
x_vers = xformers.__version__
|
x_vers = xformers.__version__
|
||||||
#I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
|
# XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
|
||||||
BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
|
BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def attention_xformers(q, k, v, heads, mask=None):
|
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||||
b, _, dim_head = q.shape
|
if skip_reshape:
|
||||||
dim_head //= heads
|
b, _, _, dim_head = q.shape
|
||||||
|
else:
|
||||||
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
|
||||||
|
disabled_xformers = False
|
||||||
|
|
||||||
if BROKEN_XFORMERS:
|
if BROKEN_XFORMERS:
|
||||||
if b * heads > 65535:
|
if b * heads > 65535:
|
||||||
return attention_pytorch(q, k, v, heads, mask)
|
disabled_xformers = True
|
||||||
|
|
||||||
q, k, v = map(
|
if not disabled_xformers:
|
||||||
lambda t: t.unsqueeze(3)
|
if torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||||
.reshape(b, -1, heads, dim_head)
|
disabled_xformers = True
|
||||||
.permute(0, 2, 1, 3)
|
|
||||||
.reshape(b * heads, -1, dim_head)
|
if disabled_xformers:
|
||||||
.contiguous(),
|
return attention_pytorch(q, k, v, heads, mask)
|
||||||
(q, k, v),
|
|
||||||
)
|
if skip_reshape:
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.reshape(b, -1, heads, dim_head),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
pad = 8 - q.shape[1] % 8
|
pad = 8 - q.shape[1] % 8
|
||||||
@ -326,21 +379,30 @@ def attention_xformers(q, k, v, heads, mask=None):
|
|||||||
|
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||||
|
|
||||||
out = (
|
if skip_reshape:
|
||||||
out.unsqueeze(0)
|
out = (
|
||||||
.reshape(b, heads, -1, dim_head)
|
out.unsqueeze(0)
|
||||||
.permute(0, 2, 1, 3)
|
.reshape(b, heads, -1, dim_head)
|
||||||
.reshape(b, -1, heads * dim_head)
|
.permute(0, 2, 1, 3)
|
||||||
)
|
.reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
out = (
|
||||||
|
out.reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def attention_pytorch(q, k, v, heads, mask=None):
|
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||||
b, _, dim_head = q.shape
|
if skip_reshape:
|
||||||
dim_head //= heads
|
b, _, _, dim_head = q.shape
|
||||||
q, k, v = map(
|
else:
|
||||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
b, _, dim_head = q.shape
|
||||||
(q, k, v),
|
dim_head //= heads
|
||||||
)
|
q, k, v = map(
|
||||||
|
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
out = (
|
out = (
|
||||||
@ -384,10 +446,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):
|
|||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
|
self.attn_precision = attn_precision
|
||||||
|
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.dim_head = dim_head
|
self.dim_head = dim_head
|
||||||
@ -409,15 +472,15 @@ class CrossAttention(nn.Module):
|
|||||||
v = self.to_v(context)
|
v = self.to_v(context)
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
out = optimized_attention(q, k, v, self.heads)
|
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||||
else:
|
else:
|
||||||
out = optimized_attention_masked(q, k, v, self.heads, mask)
|
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
|
||||||
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
|
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.ff_in = ff_in or inner_dim is not None
|
self.ff_in = ff_in or inner_dim is not None
|
||||||
@ -425,6 +488,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
inner_dim = dim
|
inner_dim = dim
|
||||||
|
|
||||||
self.is_res = inner_dim == dim
|
self.is_res = inner_dim == dim
|
||||||
|
self.attn_precision = attn_precision
|
||||||
|
|
||||||
if self.ff_in:
|
if self.ff_in:
|
||||||
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
|
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||||
@ -432,7 +496,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
self.disable_self_attn = disable_self_attn
|
self.disable_self_attn = disable_self_attn
|
||||||
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||||
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
||||||
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
if disable_temporal_crossattention:
|
if disable_temporal_crossattention:
|
||||||
@ -446,20 +510,16 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
context_dim_attn2 = context_dim
|
context_dim_attn2 = context_dim
|
||||||
|
|
||||||
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
||||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||||
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||||
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||||
self.checkpoint = checkpoint
|
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
self.d_head = d_head
|
self.d_head = d_head
|
||||||
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
|
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
|
||||||
|
|
||||||
def forward(self, x, context=None, transformer_options={}):
|
def forward(self, x, context=None, transformer_options={}):
|
||||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
|
||||||
|
|
||||||
def _forward(self, x, context=None, transformer_options={}):
|
|
||||||
extra_options = {}
|
extra_options = {}
|
||||||
block = transformer_options.get("block", None)
|
block = transformer_options.get("block", None)
|
||||||
block_index = transformer_options.get("block_index", 0)
|
block_index = transformer_options.get("block_index", 0)
|
||||||
@ -476,6 +536,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
extra_options["n_heads"] = self.n_heads
|
extra_options["n_heads"] = self.n_heads
|
||||||
extra_options["dim_head"] = self.d_head
|
extra_options["dim_head"] = self.d_head
|
||||||
|
extra_options["attn_precision"] = self.attn_precision
|
||||||
|
|
||||||
if self.ff_in:
|
if self.ff_in:
|
||||||
x_skip = x
|
x_skip = x
|
||||||
@ -586,7 +647,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
def __init__(self, in_channels, n_heads, d_head,
|
def __init__(self, in_channels, n_heads, d_head,
|
||||||
depth=1, dropout=0., context_dim=None,
|
depth=1, dropout=0., context_dim=None,
|
||||||
disable_self_attn=False, use_linear=False,
|
disable_self_attn=False, use_linear=False,
|
||||||
use_checkpoint=True, dtype=None, device=None, operations=ops):
|
use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if exists(context_dim) and not isinstance(context_dim, list):
|
if exists(context_dim) and not isinstance(context_dim, list):
|
||||||
context_dim = [context_dim] * depth
|
context_dim = [context_dim] * depth
|
||||||
@ -604,7 +665,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
self.transformer_blocks = nn.ModuleList(
|
||||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
||||||
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
|
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
|
||||||
for d in range(depth)]
|
for d in range(depth)]
|
||||||
)
|
)
|
||||||
if not use_linear:
|
if not use_linear:
|
||||||
@ -625,7 +686,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
if not self.use_linear:
|
if not self.use_linear:
|
||||||
x = self.proj_in(x)
|
x = self.proj_in(x)
|
||||||
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
x = x.movedim(1, 3).flatten(1, 2).contiguous()
|
||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_in(x)
|
x = self.proj_in(x)
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
@ -633,7 +694,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
x = block(x, context=context[i], transformer_options=transformer_options)
|
x = block(x, context=context[i], transformer_options=transformer_options)
|
||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
|
||||||
if not self.use_linear:
|
if not self.use_linear:
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
return x + x_in
|
return x + x_in
|
||||||
@ -660,6 +721,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
|||||||
disable_self_attn=False,
|
disable_self_attn=False,
|
||||||
disable_temporal_crossattention=False,
|
disable_temporal_crossattention=False,
|
||||||
max_time_embed_period: int = 10000,
|
max_time_embed_period: int = 10000,
|
||||||
|
attn_precision=None,
|
||||||
dtype=None, device=None, operations=ops
|
dtype=None, device=None, operations=ops
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -672,6 +734,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
|||||||
context_dim=context_dim,
|
context_dim=context_dim,
|
||||||
use_linear=use_linear,
|
use_linear=use_linear,
|
||||||
disable_self_attn=disable_self_attn,
|
disable_self_attn=disable_self_attn,
|
||||||
|
attn_precision=attn_precision,
|
||||||
dtype=dtype, device=device, operations=operations
|
dtype=dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
self.time_depth = time_depth
|
self.time_depth = time_depth
|
||||||
@ -701,6 +764,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
|||||||
inner_dim=time_mix_inner_dim,
|
inner_dim=time_mix_inner_dim,
|
||||||
disable_self_attn=disable_self_attn,
|
disable_self_attn=disable_self_attn,
|
||||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||||
|
attn_precision=attn_precision,
|
||||||
dtype=dtype, device=device, operations=operations
|
dtype=dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
for _ in range(self.depth)
|
for _ in range(self.depth)
|
||||||
|
|||||||
962
comfy/ldm/modules/diffusionmodules/mmdit.py
Normal file
962
comfy/ldm/modules/diffusionmodules/mmdit.py
Normal file
@ -0,0 +1,962 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from .. import attention
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
def default(x, y):
|
||||||
|
if x is not None:
|
||||||
|
return x
|
||||||
|
return y
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features,
|
||||||
|
hidden_features=None,
|
||||||
|
out_features=None,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
norm_layer=None,
|
||||||
|
bias=True,
|
||||||
|
drop=0.,
|
||||||
|
use_conv=False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
drop_probs = drop
|
||||||
|
linear_layer = partial(operations.Conv2d, kernel_size=1) if use_conv else operations.Linear
|
||||||
|
|
||||||
|
self.fc1 = linear_layer(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.drop1 = nn.Dropout(drop_probs)
|
||||||
|
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||||
|
self.fc2 = linear_layer(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
||||||
|
self.drop2 = nn.Dropout(drop_probs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop1(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
""" 2D Image to Patch Embedding
|
||||||
|
"""
|
||||||
|
dynamic_img_pad: torch.jit.Final[bool]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
img_size: Optional[int] = 224,
|
||||||
|
patch_size: int = 16,
|
||||||
|
in_chans: int = 3,
|
||||||
|
embed_dim: int = 768,
|
||||||
|
norm_layer = None,
|
||||||
|
flatten: bool = True,
|
||||||
|
bias: bool = True,
|
||||||
|
strict_img_size: bool = True,
|
||||||
|
dynamic_img_pad: bool = True,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = (patch_size, patch_size)
|
||||||
|
if img_size is not None:
|
||||||
|
self.img_size = (img_size, img_size)
|
||||||
|
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
|
||||||
|
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||||
|
else:
|
||||||
|
self.img_size = None
|
||||||
|
self.grid_size = None
|
||||||
|
self.num_patches = None
|
||||||
|
|
||||||
|
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||||
|
self.flatten = flatten
|
||||||
|
self.strict_img_size = strict_img_size
|
||||||
|
self.dynamic_img_pad = dynamic_img_pad
|
||||||
|
|
||||||
|
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||||
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
# if self.img_size is not None:
|
||||||
|
# if self.strict_img_size:
|
||||||
|
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||||
|
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
|
||||||
|
# elif not self.dynamic_img_pad:
|
||||||
|
# _assert(
|
||||||
|
# H % self.patch_size[0] == 0,
|
||||||
|
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||||
|
# )
|
||||||
|
# _assert(
|
||||||
|
# W % self.patch_size[1] == 0,
|
||||||
|
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||||
|
# )
|
||||||
|
if self.dynamic_img_pad:
|
||||||
|
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
||||||
|
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
||||||
|
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
|
||||||
|
x = self.proj(x)
|
||||||
|
if self.flatten:
|
||||||
|
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||||
|
x = self.norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def modulate(x, shift, scale):
|
||||||
|
if shift is None:
|
||||||
|
shift = torch.zeros_like(scale)
|
||||||
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Sine/Cosine Positional Embedding Functions #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed(
|
||||||
|
embed_dim,
|
||||||
|
grid_size,
|
||||||
|
cls_token=False,
|
||||||
|
extra_tokens=0,
|
||||||
|
scaling_factor=None,
|
||||||
|
offset=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
grid_size: int of the grid height and width
|
||||||
|
return:
|
||||||
|
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||||
|
"""
|
||||||
|
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||||
|
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||||
|
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||||
|
grid = np.stack(grid, axis=0)
|
||||||
|
if scaling_factor is not None:
|
||||||
|
grid = grid / scaling_factor
|
||||||
|
if offset is not None:
|
||||||
|
grid = grid - offset
|
||||||
|
|
||||||
|
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||||
|
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||||
|
if cls_token and extra_tokens > 0:
|
||||||
|
pos_embed = np.concatenate(
|
||||||
|
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
||||||
|
)
|
||||||
|
return pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
|
||||||
|
# use half of dimensions to encode grid_h
|
||||||
|
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||||
|
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||||
|
|
||||||
|
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||||
|
"""
|
||||||
|
embed_dim: output dimension for each position
|
||||||
|
pos: a list of positions to be encoded: size (M,)
|
||||||
|
out: (M, D)
|
||||||
|
"""
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||||
|
omega /= embed_dim / 2.0
|
||||||
|
omega = 1.0 / 10000**omega # (D/2,)
|
||||||
|
|
||||||
|
pos = pos.reshape(-1) # (M,)
|
||||||
|
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||||
|
|
||||||
|
emb_sin = np.sin(out) # (M, D/2)
|
||||||
|
emb_cos = np.cos(out) # (M, D/2)
|
||||||
|
|
||||||
|
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32):
|
||||||
|
omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
|
||||||
|
omega /= embed_dim / 2.0
|
||||||
|
omega = 1.0 / 10000**omega # (D/2,)
|
||||||
|
pos = pos.reshape(-1) # (M,)
|
||||||
|
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||||
|
emb_sin = torch.sin(out) # (M, D/2)
|
||||||
|
emb_cos = torch.cos(out) # (M, D/2)
|
||||||
|
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32):
|
||||||
|
small = min(h, w)
|
||||||
|
val_h = (h / small) * val_magnitude
|
||||||
|
val_w = (w / small) * val_magnitude
|
||||||
|
grid_h, grid_w = torch.meshgrid(torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), indexing='ij')
|
||||||
|
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||||
|
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||||
|
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Embedding Layers for Timesteps and Class Labels #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds scalar timesteps into vector representations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def timestep_embedding(t, dim, max_period=10000):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||||
|
These may be fractional.
|
||||||
|
:param dim: the dimension of the output.
|
||||||
|
:param max_period: controls the minimum frequency of the embeddings.
|
||||||
|
:return: an (N, D) Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(
|
||||||
|
-math.log(max_period)
|
||||||
|
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
|
||||||
|
/ half
|
||||||
|
)
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat(
|
||||||
|
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||||
|
)
|
||||||
|
if torch.is_floating_point(t):
|
||||||
|
embedding = embedding.to(dtype=t.dtype)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def forward(self, t, dtype, **kwargs):
|
||||||
|
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||||
|
t_emb = self.mlp(t_freq)
|
||||||
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
|
class VectorEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds a flat vector of dimension input_dim
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
operations.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
emb = self.mlp(x)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# Core DiT Model #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def split_qkv(qkv, head_dim):
|
||||||
|
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||||
|
return qkv[0], qkv[1], qkv[2]
|
||||||
|
|
||||||
|
def optimized_attention(qkv, num_heads):
|
||||||
|
return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int = 8,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
qk_scale: Optional[float] = None,
|
||||||
|
proj_drop: float = 0.0,
|
||||||
|
attn_mode: str = "xformers",
|
||||||
|
pre_only: bool = False,
|
||||||
|
qk_norm: Optional[str] = None,
|
||||||
|
rmsnorm: bool = False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
|
||||||
|
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||||
|
if not pre_only:
|
||||||
|
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
assert attn_mode in self.ATTENTION_MODES
|
||||||
|
self.attn_mode = attn_mode
|
||||||
|
self.pre_only = pre_only
|
||||||
|
|
||||||
|
if qk_norm == "rms":
|
||||||
|
self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||||
|
self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||||
|
elif qk_norm == "ln":
|
||||||
|
self.ln_q = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||||
|
self.ln_k = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||||
|
elif qk_norm is None:
|
||||||
|
self.ln_q = nn.Identity()
|
||||||
|
self.ln_k = nn.Identity()
|
||||||
|
else:
|
||||||
|
raise ValueError(qk_norm)
|
||||||
|
|
||||||
|
def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
B, L, C = x.shape
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
q, k, v = split_qkv(qkv, self.head_dim)
|
||||||
|
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
||||||
|
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
||||||
|
return (q, k, v)
|
||||||
|
|
||||||
|
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert not self.pre_only
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
qkv = self.pre_attention(x)
|
||||||
|
x = optimized_attention(
|
||||||
|
qkv, num_heads=self.num_heads
|
||||||
|
)
|
||||||
|
x = self.post_attention(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the RMSNorm normalization layer.
|
||||||
|
Args:
|
||||||
|
dim (int): The dimension of the input tensor.
|
||||||
|
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||||
|
Attributes:
|
||||||
|
eps (float): A small value added to the denominator for numerical stability.
|
||||||
|
weight (nn.Parameter): Learnable scaling parameter.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.learnable_scale = elementwise_affine
|
||||||
|
if self.learnable_scale:
|
||||||
|
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||||
|
else:
|
||||||
|
self.register_parameter("weight", None)
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
"""
|
||||||
|
Apply the RMSNorm normalization to the input tensor.
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The normalized tensor.
|
||||||
|
"""
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass through the RMSNorm layer.
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The output tensor after applying RMSNorm.
|
||||||
|
"""
|
||||||
|
x = self._norm(x)
|
||||||
|
if self.learnable_scale:
|
||||||
|
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLUFeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
multiple_of: int,
|
||||||
|
ffn_dim_multiplier: Optional[float] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the FeedForward module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): Input dimension.
|
||||||
|
hidden_dim (int): Hidden dimension of the feedforward layer.
|
||||||
|
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
||||||
|
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
||||||
|
w2 (RowParallelLinear): Linear transformation for the second layer.
|
||||||
|
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
hidden_dim = int(2 * hidden_dim / 3)
|
||||||
|
# custom dim factor multiplier
|
||||||
|
if ffn_dim_multiplier is not None:
|
||||||
|
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||||
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
|
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||||
|
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||||
|
|
||||||
|
|
||||||
|
class DismantledBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A DiT block with gated adaptive layer norm (adaLN) conditioning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: str = "xformers",
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
pre_only: bool = False,
|
||||||
|
rmsnorm: bool = False,
|
||||||
|
scale_mod_only: bool = False,
|
||||||
|
swiglu: bool = False,
|
||||||
|
qk_norm: Optional[str] = None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**block_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert attn_mode in self.ATTENTION_MODES
|
||||||
|
if not rmsnorm:
|
||||||
|
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.attn = SelfAttention(
|
||||||
|
dim=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
pre_only=pre_only,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
rmsnorm=rmsnorm,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
if not pre_only:
|
||||||
|
if not rmsnorm:
|
||||||
|
self.norm2 = operations.LayerNorm(
|
||||||
|
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
if not pre_only:
|
||||||
|
if not swiglu:
|
||||||
|
self.mlp = Mlp(
|
||||||
|
in_features=hidden_size,
|
||||||
|
hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=lambda: nn.GELU(approximate="tanh"),
|
||||||
|
drop=0,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mlp = SwiGLUFeedForward(
|
||||||
|
dim=hidden_size,
|
||||||
|
hidden_dim=mlp_hidden_dim,
|
||||||
|
multiple_of=256,
|
||||||
|
)
|
||||||
|
self.scale_mod_only = scale_mod_only
|
||||||
|
if not scale_mod_only:
|
||||||
|
n_mods = 6 if not pre_only else 2
|
||||||
|
else:
|
||||||
|
n_mods = 4 if not pre_only else 1
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(), operations.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
self.pre_only = pre_only
|
||||||
|
|
||||||
|
def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not self.pre_only:
|
||||||
|
if not self.scale_mod_only:
|
||||||
|
(
|
||||||
|
shift_msa,
|
||||||
|
scale_msa,
|
||||||
|
gate_msa,
|
||||||
|
shift_mlp,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
) = self.adaLN_modulation(c).chunk(6, dim=1)
|
||||||
|
else:
|
||||||
|
shift_msa = None
|
||||||
|
shift_mlp = None
|
||||||
|
(
|
||||||
|
scale_msa,
|
||||||
|
gate_msa,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
) = self.adaLN_modulation(
|
||||||
|
c
|
||||||
|
).chunk(4, dim=1)
|
||||||
|
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||||
|
return qkv, (
|
||||||
|
x,
|
||||||
|
gate_msa,
|
||||||
|
shift_mlp,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if not self.scale_mod_only:
|
||||||
|
(
|
||||||
|
shift_msa,
|
||||||
|
scale_msa,
|
||||||
|
) = self.adaLN_modulation(
|
||||||
|
c
|
||||||
|
).chunk(2, dim=1)
|
||||||
|
else:
|
||||||
|
shift_msa = None
|
||||||
|
scale_msa = self.adaLN_modulation(c)
|
||||||
|
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||||
|
return qkv, None
|
||||||
|
|
||||||
|
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
||||||
|
assert not self.pre_only
|
||||||
|
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||||
|
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||||
|
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert not self.pre_only
|
||||||
|
qkv, intermediates = self.pre_attention(x, c)
|
||||||
|
attn = optimized_attention(
|
||||||
|
qkv,
|
||||||
|
num_heads=self.attn.num_heads,
|
||||||
|
)
|
||||||
|
return self.post_attention(attn, *intermediates)
|
||||||
|
|
||||||
|
|
||||||
|
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||||
|
if use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(
|
||||||
|
_block_mixing, *args, use_reentrant=False, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return _block_mixing(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _block_mixing(context, x, context_block, x_block, c):
|
||||||
|
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||||
|
|
||||||
|
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||||
|
|
||||||
|
o = []
|
||||||
|
for t in range(3):
|
||||||
|
o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
|
||||||
|
qkv = tuple(o)
|
||||||
|
|
||||||
|
attn = optimized_attention(
|
||||||
|
qkv,
|
||||||
|
num_heads=x_block.attn.num_heads,
|
||||||
|
)
|
||||||
|
context_attn, x_attn = (
|
||||||
|
attn[:, : context_qkv[0].shape[1]],
|
||||||
|
attn[:, context_qkv[0].shape[1] :],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context_block.pre_only:
|
||||||
|
context = context_block.post_attention(context_attn, *context_intermediates)
|
||||||
|
|
||||||
|
else:
|
||||||
|
context = None
|
||||||
|
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||||
|
return context, x
|
||||||
|
|
||||||
|
|
||||||
|
class JointBlock(nn.Module):
|
||||||
|
"""just a small wrapper to serve as a fsdp unit"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
pre_only = kwargs.pop("pre_only")
|
||||||
|
qk_norm = kwargs.pop("qk_norm", None)
|
||||||
|
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||||
|
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return block_mixing(
|
||||||
|
*args, context_block=self.context_block, x_block=self.x_block, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of DiT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
patch_size: int,
|
||||||
|
out_channels: int,
|
||||||
|
total_out_channels: Optional[int] = None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.linear = (
|
||||||
|
operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
if (total_out_channels is None)
|
||||||
|
else operations.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
|
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
|
x = modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class SelfAttentionContext(nn.Module):
|
||||||
|
def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
dim_head = dim // heads
|
||||||
|
inner_dim = dim
|
||||||
|
|
||||||
|
self.heads = heads
|
||||||
|
self.dim_head = dim_head
|
||||||
|
|
||||||
|
self.qkv = operations.Linear(dim, dim * 3, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.proj = operations.Linear(inner_dim, dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
q, k, v = split_qkv(qkv, self.dim_head)
|
||||||
|
x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
|
||||||
|
return self.proj(x)
|
||||||
|
|
||||||
|
class ContextProcessorBlock(nn.Module):
|
||||||
|
def __init__(self, context_size, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.attn = SelfAttentionContext(context_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.norm2 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.mlp = Mlp(in_features=context_size, hidden_features=(context_size * 4), act_layer=lambda: nn.GELU(approximate="tanh"), drop=0, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x += self.attn(self.norm1(x))
|
||||||
|
x += self.mlp(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
class ContextProcessor(nn.Module):
|
||||||
|
def __init__(self, context_size, num_layers, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = torch.nn.ModuleList([ContextProcessorBlock(context_size, dtype=dtype, device=device, operations=operations) for i in range(num_layers)])
|
||||||
|
self.norm = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for i, l in enumerate(self.layers):
|
||||||
|
x = l(x)
|
||||||
|
return self.norm(x)
|
||||||
|
|
||||||
|
class MMDiT(nn.Module):
|
||||||
|
"""
|
||||||
|
Diffusion model with a Transformer backbone.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int = 32,
|
||||||
|
patch_size: int = 2,
|
||||||
|
in_channels: int = 4,
|
||||||
|
depth: int = 28,
|
||||||
|
# hidden_size: Optional[int] = None,
|
||||||
|
# num_heads: Optional[int] = None,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
learn_sigma: bool = False,
|
||||||
|
adm_in_channels: Optional[int] = None,
|
||||||
|
context_embedder_config: Optional[Dict] = None,
|
||||||
|
compile_core: bool = False,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
register_length: int = 0,
|
||||||
|
attn_mode: str = "torch",
|
||||||
|
rmsnorm: bool = False,
|
||||||
|
scale_mod_only: bool = False,
|
||||||
|
swiglu: bool = False,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
pos_embed_scaling_factor: Optional[float] = None,
|
||||||
|
pos_embed_offset: Optional[float] = None,
|
||||||
|
pos_embed_max_size: Optional[int] = None,
|
||||||
|
num_patches = None,
|
||||||
|
qk_norm: Optional[str] = None,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
context_processor_layers = None,
|
||||||
|
context_size = 4096,
|
||||||
|
dtype = None, #TODO
|
||||||
|
device = None,
|
||||||
|
operations = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
self.learn_sigma = learn_sigma
|
||||||
|
self.in_channels = in_channels
|
||||||
|
default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||||
|
self.out_channels = default(out_channels, default_out_channels)
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||||
|
self.pos_embed_offset = pos_embed_offset
|
||||||
|
self.pos_embed_max_size = pos_embed_max_size
|
||||||
|
|
||||||
|
# hidden_size = default(hidden_size, 64 * depth)
|
||||||
|
# num_heads = default(num_heads, hidden_size // 64)
|
||||||
|
|
||||||
|
# apply magic --> this defines a head_size of 64
|
||||||
|
self.hidden_size = 64 * depth
|
||||||
|
num_heads = depth
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.x_embedder = PatchEmbed(
|
||||||
|
input_size,
|
||||||
|
patch_size,
|
||||||
|
in_channels,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=True,
|
||||||
|
strict_img_size=self.pos_embed_max_size is None,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.y_embedder = None
|
||||||
|
if adm_in_channels is not None:
|
||||||
|
assert isinstance(adm_in_channels, int)
|
||||||
|
self.y_embedder = VectorEmbedder(adm_in_channels, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
if context_processor_layers is not None:
|
||||||
|
self.context_processor = ContextProcessor(context_size, context_processor_layers, dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
self.context_processor = None
|
||||||
|
|
||||||
|
self.context_embedder = nn.Identity()
|
||||||
|
if context_embedder_config is not None:
|
||||||
|
if context_embedder_config["target"] == "torch.nn.Linear":
|
||||||
|
self.context_embedder = operations.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.register_length = register_length
|
||||||
|
if self.register_length > 0:
|
||||||
|
self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
# num_patches = self.x_embedder.num_patches
|
||||||
|
# Will use fixed sin-cos embedding:
|
||||||
|
# just use a buffer already
|
||||||
|
if num_patches is not None:
|
||||||
|
self.register_buffer(
|
||||||
|
"pos_embed",
|
||||||
|
torch.empty(1, num_patches, self.hidden_size, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.pos_embed = None
|
||||||
|
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.joint_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
JointBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
pre_only=i == depth - 1,
|
||||||
|
rmsnorm=rmsnorm,
|
||||||
|
scale_mod_only=scale_mod_only,
|
||||||
|
swiglu=swiglu,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
for i in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
if compile_core:
|
||||||
|
assert False
|
||||||
|
self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
|
||||||
|
|
||||||
|
def cropped_pos_embed(self, hw, device=None):
|
||||||
|
p = self.x_embedder.patch_size[0]
|
||||||
|
h, w = hw
|
||||||
|
# patched size
|
||||||
|
h = (h + 1) // p
|
||||||
|
w = (w + 1) // p
|
||||||
|
if self.pos_embed is None:
|
||||||
|
return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
|
||||||
|
assert self.pos_embed_max_size is not None
|
||||||
|
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
||||||
|
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
||||||
|
top = (self.pos_embed_max_size - h) // 2
|
||||||
|
left = (self.pos_embed_max_size - w) // 2
|
||||||
|
spatial_pos_embed = rearrange(
|
||||||
|
self.pos_embed,
|
||||||
|
"1 (h w) c -> 1 h w c",
|
||||||
|
h=self.pos_embed_max_size,
|
||||||
|
w=self.pos_embed_max_size,
|
||||||
|
)
|
||||||
|
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||||
|
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||||
|
# print(spatial_pos_embed, top, left, h, w)
|
||||||
|
# # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.875, 7.875, device=device) #matches exactly for 1024 res
|
||||||
|
# t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.5, 7.5, device=device) #scales better
|
||||||
|
# # print(t)
|
||||||
|
# return t
|
||||||
|
return spatial_pos_embed
|
||||||
|
|
||||||
|
def unpatchify(self, x, hw=None):
|
||||||
|
"""
|
||||||
|
x: (N, T, patch_size**2 * C)
|
||||||
|
imgs: (N, H, W, C)
|
||||||
|
"""
|
||||||
|
c = self.out_channels
|
||||||
|
p = self.x_embedder.patch_size[0]
|
||||||
|
if hw is None:
|
||||||
|
h = w = int(x.shape[1] ** 0.5)
|
||||||
|
else:
|
||||||
|
h, w = hw
|
||||||
|
h = (h + 1) // p
|
||||||
|
w = (w + 1) // p
|
||||||
|
assert h * w == x.shape[1]
|
||||||
|
|
||||||
|
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||||
|
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||||
|
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||||
|
return imgs
|
||||||
|
|
||||||
|
def forward_core_with_concat(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
c_mod: torch.Tensor,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if self.register_length > 0:
|
||||||
|
context = torch.cat(
|
||||||
|
(
|
||||||
|
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
|
||||||
|
default(context, torch.Tensor([]).type_as(x)),
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# context is B, L', D
|
||||||
|
# x is B, L, D
|
||||||
|
for block in self.joint_blocks:
|
||||||
|
context, x = block(
|
||||||
|
context,
|
||||||
|
x,
|
||||||
|
c=c_mod,
|
||||||
|
use_checkpoint=self.use_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass of DiT.
|
||||||
|
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||||
|
t: (N,) tensor of diffusion timesteps
|
||||||
|
y: (N,) tensor of class labels
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.context_processor is not None:
|
||||||
|
context = self.context_processor(context)
|
||||||
|
|
||||||
|
hw = x.shape[-2:]
|
||||||
|
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
|
||||||
|
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||||
|
if y is not None and self.y_embedder is not None:
|
||||||
|
y = self.y_embedder(y) # (N, D)
|
||||||
|
c = c + y # (N, D)
|
||||||
|
|
||||||
|
if context is not None:
|
||||||
|
context = self.context_embedder(context)
|
||||||
|
|
||||||
|
x = self.forward_core_with_concat(x, c, context)
|
||||||
|
|
||||||
|
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||||
|
return x[:,:,:hw[-2],:hw[-1]]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAISignatureMMDITWrapper(MMDiT):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return super().forward(x, timesteps, context=context, y=y)
|
||||||
|
|
||||||
@ -3,7 +3,6 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from einops import rearrange
|
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|||||||
@ -258,7 +258,7 @@ class ResBlock(TimestepBlock):
|
|||||||
else:
|
else:
|
||||||
if emb_out is not None:
|
if emb_out is not None:
|
||||||
if self.exchange_temb_dims:
|
if self.exchange_temb_dims:
|
||||||
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
emb_out = emb_out.movedim(1, 2)
|
||||||
h = h + emb_out
|
h = h + emb_out
|
||||||
h = self.out_layers(h)
|
h = self.out_layers(h)
|
||||||
return self.skip_connection(x) + h
|
return self.skip_connection(x) + h
|
||||||
@ -431,6 +431,7 @@ class UNetModel(nn.Module):
|
|||||||
video_kernel_size=None,
|
video_kernel_size=None,
|
||||||
disable_temporal_crossattention=False,
|
disable_temporal_crossattention=False,
|
||||||
max_ddpm_temb_period=10000,
|
max_ddpm_temb_period=10000,
|
||||||
|
attn_precision=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=ops,
|
operations=ops,
|
||||||
):
|
):
|
||||||
@ -550,13 +551,14 @@ class UNetModel(nn.Module):
|
|||||||
disable_self_attn=disable_self_attn,
|
disable_self_attn=disable_self_attn,
|
||||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||||
max_time_embed_period=max_ddpm_temb_period,
|
max_time_embed_period=max_ddpm_temb_period,
|
||||||
|
attn_precision=attn_precision,
|
||||||
dtype=self.dtype, device=device, operations=operations
|
dtype=self.dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return SpatialTransformer(
|
return SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
|
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
|
||||||
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
|
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
|
||||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_resblock(
|
def get_resblock(
|
||||||
|
|||||||
@ -29,6 +29,7 @@ def load_lora(lora, to_load):
|
|||||||
|
|
||||||
regular_lora = "{}.lora_up.weight".format(x)
|
regular_lora = "{}.lora_up.weight".format(x)
|
||||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||||
|
diffusers2_lora = "{}.lora_B.weight".format(x)
|
||||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||||
A_name = None
|
A_name = None
|
||||||
|
|
||||||
@ -40,6 +41,10 @@ def load_lora(lora, to_load):
|
|||||||
A_name = diffusers_lora
|
A_name = diffusers_lora
|
||||||
B_name = "{}_lora.down.weight".format(x)
|
B_name = "{}_lora.down.weight".format(x)
|
||||||
mid_name = None
|
mid_name = None
|
||||||
|
elif diffusers2_lora in lora.keys():
|
||||||
|
A_name = diffusers2_lora
|
||||||
|
B_name = "{}.lora_A.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
elif transformers_lora in lora.keys():
|
elif transformers_lora in lora.keys():
|
||||||
A_name = transformers_lora
|
A_name = transformers_lora
|
||||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||||
@ -164,6 +169,7 @@ def load_lora(lora, to_load):
|
|||||||
for x in lora.keys():
|
for x in lora.keys():
|
||||||
if x not in loaded_keys:
|
if x not in loaded_keys:
|
||||||
logging.warning("lora key not loaded: {}".format(x))
|
logging.warning("lora key not loaded: {}".format(x))
|
||||||
|
|
||||||
return patch_dict
|
return patch_dict
|
||||||
|
|
||||||
def model_lora_keys_clip(model, key_map={}):
|
def model_lora_keys_clip(model, key_map={}):
|
||||||
@ -217,7 +223,8 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
def model_lora_keys_unet(model, key_map={}):
|
def model_lora_keys_unet(model, key_map={}):
|
||||||
sdk = model.state_dict().keys()
|
sd = model.state_dict()
|
||||||
|
sdk = sd.keys()
|
||||||
|
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||||
@ -238,4 +245,17 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
if diffusers_lora_key.endswith(".to_out.0"):
|
if diffusers_lora_key.endswith(".to_out.0"):
|
||||||
diffusers_lora_key = diffusers_lora_key[:-2]
|
diffusers_lora_key = diffusers_lora_key[:-2]
|
||||||
key_map[diffusers_lora_key] = unet_key
|
key_map[diffusers_lora_key] = unet_key
|
||||||
|
|
||||||
|
if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
|
||||||
|
for i in range(model.model_config.unet_config.get("depth", 0)):
|
||||||
|
k = "transformer.transformer_blocks.{}.attn.".format(i)
|
||||||
|
qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i)
|
||||||
|
proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i)
|
||||||
|
if qkv in sd:
|
||||||
|
offset = sd[qkv].shape[0] // 3
|
||||||
|
key_map["{}to_q".format(k)] = (qkv, (0, 0, offset))
|
||||||
|
key_map["{}to_k".format(k)] = (qkv, (0, offset, offset))
|
||||||
|
key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset))
|
||||||
|
key_map["{}to_out.0".format(k)] = proj
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|||||||
@ -5,11 +5,16 @@ from comfy.ldm.cascade.stage_c import StageC
|
|||||||
from comfy.ldm.cascade.stage_b import StageB
|
from comfy.ldm.cascade.stage_b import StageB
|
||||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||||
|
import comfy.ldm.audio.dit
|
||||||
|
import comfy.ldm.audio.embedders
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from . import utils
|
from . import utils
|
||||||
|
import comfy.latent_formats
|
||||||
|
import math
|
||||||
|
|
||||||
class ModelType(Enum):
|
class ModelType(Enum):
|
||||||
EPS = 1
|
EPS = 1
|
||||||
@ -17,9 +22,11 @@ class ModelType(Enum):
|
|||||||
V_PREDICTION_EDM = 3
|
V_PREDICTION_EDM = 3
|
||||||
STABLE_CASCADE = 4
|
STABLE_CASCADE = 4
|
||||||
EDM = 5
|
EDM = 5
|
||||||
|
FLOW = 6
|
||||||
|
V_PREDICTION_CONTINUOUS = 7
|
||||||
|
|
||||||
|
|
||||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
|
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||||
|
|
||||||
|
|
||||||
def model_sampling(model_config, model_type):
|
def model_sampling(model_config, model_type):
|
||||||
@ -32,12 +39,18 @@ def model_sampling(model_config, model_type):
|
|||||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||||
c = V_PREDICTION
|
c = V_PREDICTION
|
||||||
s = ModelSamplingContinuousEDM
|
s = ModelSamplingContinuousEDM
|
||||||
|
elif model_type == ModelType.FLOW:
|
||||||
|
c = comfy.model_sampling.CONST
|
||||||
|
s = comfy.model_sampling.ModelSamplingDiscreteFlow
|
||||||
elif model_type == ModelType.STABLE_CASCADE:
|
elif model_type == ModelType.STABLE_CASCADE:
|
||||||
c = EPS
|
c = EPS
|
||||||
s = StableCascadeSampling
|
s = StableCascadeSampling
|
||||||
elif model_type == ModelType.EDM:
|
elif model_type == ModelType.EDM:
|
||||||
c = EDM
|
c = EDM
|
||||||
s = ModelSamplingContinuousEDM
|
s = ModelSamplingContinuousEDM
|
||||||
|
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
|
||||||
|
c = V_PREDICTION
|
||||||
|
s = ModelSamplingContinuousV
|
||||||
|
|
||||||
class ModelSampling(s, c):
|
class ModelSampling(s, c):
|
||||||
pass
|
pass
|
||||||
@ -60,6 +73,9 @@ class BaseModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
operations = comfy.ops.disable_weight_init
|
operations = comfy.ops.disable_weight_init
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
|
if comfy.model_management.force_channels_last():
|
||||||
|
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||||
|
logging.debug("using channels last mode for diffusion model")
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.model_sampling = model_sampling(model_config, model_type)
|
self.model_sampling = model_sampling(model_config, model_type)
|
||||||
|
|
||||||
@ -162,7 +178,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
c_concat = kwargs.get("noise_concat", None)
|
c_concat = kwargs.get("noise_concat", None)
|
||||||
if c_concat is not None:
|
if c_concat is not None:
|
||||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
out['c_concat'] = comfy.conds.CONDNoiseShape(c_concat)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -201,9 +217,6 @@ class BaseModel(torch.nn.Module):
|
|||||||
unet_state_dict = self.diffusion_model.state_dict()
|
unet_state_dict = self.diffusion_model.state_dict()
|
||||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||||
|
|
||||||
if self.get_dtype() == torch.float16:
|
|
||||||
extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
|
|
||||||
|
|
||||||
if self.model_type == ModelType.V_PREDICTION:
|
if self.model_type == ModelType.V_PREDICTION:
|
||||||
unet_state_dict["v_pred"] = torch.tensor([])
|
unet_state_dict["v_pred"] = torch.tensor([])
|
||||||
|
|
||||||
@ -230,11 +243,11 @@ class BaseModel(torch.nn.Module):
|
|||||||
if self.manual_cast_dtype is not None:
|
if self.manual_cast_dtype is not None:
|
||||||
dtype = self.manual_cast_dtype
|
dtype = self.manual_cast_dtype
|
||||||
#TODO: this needs to be tweaked
|
#TODO: this needs to be tweaked
|
||||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
area = input_shape[0] * math.prod(input_shape[2:])
|
||||||
return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
|
return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
|
||||||
else:
|
else:
|
||||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
area = input_shape[0] * math.prod(input_shape[2:])
|
||||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||||
|
|
||||||
|
|
||||||
@ -557,3 +570,60 @@ class StableCascade_B(BaseModel):
|
|||||||
out["effnet"] = comfy.conds.CONDRegular(prior)
|
out["effnet"] = comfy.conds.CONDRegular(prior)
|
||||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SD3(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)
|
||||||
|
|
||||||
|
def encode_adm(self, **kwargs):
|
||||||
|
return kwargs["pooled_output"]
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def memory_required(self, input_shape):
|
||||||
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||||
|
dtype = self.get_dtype()
|
||||||
|
if self.manual_cast_dtype is not None:
|
||||||
|
dtype = self.manual_cast_dtype
|
||||||
|
#TODO: this probably needs to be tweaked
|
||||||
|
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||||
|
return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
|
||||||
|
else:
|
||||||
|
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||||
|
return (area * 0.3) * (1024 * 1024)
|
||||||
|
|
||||||
|
|
||||||
|
class StableAudio1(BaseModel):
|
||||||
|
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
|
||||||
|
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||||
|
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||||
|
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
|
||||||
|
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = {}
|
||||||
|
|
||||||
|
noise = kwargs.get("noise", None)
|
||||||
|
device = kwargs["device"]
|
||||||
|
|
||||||
|
seconds_start = kwargs.get("seconds_start", 0)
|
||||||
|
seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 21.53))
|
||||||
|
|
||||||
|
seconds_start_embed = self.seconds_start_embedder([seconds_start])[0].to(device)
|
||||||
|
seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)
|
||||||
|
|
||||||
|
global_embed = torch.cat([seconds_start_embed, seconds_total_embed], dim=-1).reshape((1, -1))
|
||||||
|
out['global_embed'] = comfy.conds.CONDRegular(global_embed)
|
||||||
|
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
return out
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import comfy.supported_models
|
import comfy.supported_models
|
||||||
import comfy.supported_models_base
|
import comfy.supported_models_base
|
||||||
|
import math
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
def count_blocks(state_dict_keys, prefix_string):
|
def count_blocks(state_dict_keys, prefix_string):
|
||||||
@ -26,12 +27,47 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
|||||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
||||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
||||||
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
|
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
|
||||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
|
time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
|
||||||
|
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def detect_unet_config(state_dict, key_prefix):
|
def detect_unet_config(state_dict, key_prefix):
|
||||||
state_dict_keys = list(state_dict.keys())
|
state_dict_keys = list(state_dict.keys())
|
||||||
|
|
||||||
|
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
|
||||||
|
unet_config = {}
|
||||||
|
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
||||||
|
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
|
||||||
|
unet_config["patch_size"] = patch_size
|
||||||
|
unet_config["out_channels"] = state_dict['{}final_layer.linear.weight'.format(key_prefix)].shape[0] // (patch_size * patch_size)
|
||||||
|
|
||||||
|
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
|
||||||
|
unet_config["input_size"] = None
|
||||||
|
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
|
||||||
|
if y_key in state_dict_keys:
|
||||||
|
unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
|
||||||
|
|
||||||
|
context_key = '{}context_embedder.weight'.format(key_prefix)
|
||||||
|
if context_key in state_dict_keys:
|
||||||
|
in_features = state_dict[context_key].shape[1]
|
||||||
|
out_features = state_dict[context_key].shape[0]
|
||||||
|
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
|
||||||
|
num_patches_key = '{}pos_embed'.format(key_prefix)
|
||||||
|
if num_patches_key in state_dict_keys:
|
||||||
|
num_patches = state_dict[num_patches_key].shape[1]
|
||||||
|
unet_config["num_patches"] = num_patches
|
||||||
|
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
|
||||||
|
|
||||||
|
rms_qk = '{}joint_blocks.0.context_block.attn.ln_q.weight'.format(key_prefix)
|
||||||
|
if rms_qk in state_dict_keys:
|
||||||
|
unet_config["qk_norm"] = "rms"
|
||||||
|
|
||||||
|
unet_config["pos_embed_scaling_factor"] = None #unused for inference
|
||||||
|
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
||||||
|
if context_processor in state_dict_keys:
|
||||||
|
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
||||||
|
return unet_config
|
||||||
|
|
||||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||||
unet_config = {}
|
unet_config = {}
|
||||||
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
|
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
|
||||||
@ -58,7 +94,11 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
unet_config['nhead'] = [-1, 9, 18, 18]
|
unet_config['nhead'] = [-1, 9, 18, 18]
|
||||||
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
|
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
|
||||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||||
|
return unet_config
|
||||||
|
|
||||||
|
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||||
|
unet_config = {}
|
||||||
|
unet_config["audio_model"] = "dit1.0"
|
||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
unet_config = {
|
unet_config = {
|
||||||
@ -93,6 +133,7 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
use_linear_in_transformer = False
|
use_linear_in_transformer = False
|
||||||
|
|
||||||
video_model = False
|
video_model = False
|
||||||
|
video_model_cross = False
|
||||||
|
|
||||||
current_res = 1
|
current_res = 1
|
||||||
count = 0
|
count = 0
|
||||||
@ -136,6 +177,7 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
context_dim = out[1]
|
context_dim = out[1]
|
||||||
use_linear_in_transformer = out[2]
|
use_linear_in_transformer = out[2]
|
||||||
video_model = out[3]
|
video_model = out[3]
|
||||||
|
video_model_cross = out[4]
|
||||||
else:
|
else:
|
||||||
transformer_depth.append(0)
|
transformer_depth.append(0)
|
||||||
|
|
||||||
@ -176,6 +218,7 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
unet_config["video_kernel_size"] = [3, 1, 1]
|
unet_config["video_kernel_size"] = [3, 1, 1]
|
||||||
unet_config["use_temporal_resblock"] = True
|
unet_config["use_temporal_resblock"] = True
|
||||||
unet_config["use_temporal_attention"] = True
|
unet_config["use_temporal_attention"] = True
|
||||||
|
unet_config["disable_temporal_crossattention"] = not video_model_cross
|
||||||
else:
|
else:
|
||||||
unet_config["use_temporal_resblock"] = False
|
unet_config["use_temporal_resblock"] = False
|
||||||
unet_config["use_temporal_attention"] = False
|
unet_config["use_temporal_attention"] = False
|
||||||
@ -198,6 +241,13 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||||||
else:
|
else:
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
def unet_prefix_from_state_dict(state_dict):
|
||||||
|
if "model.model.postprocess_conv.weight" in state_dict: #audio models
|
||||||
|
unet_key_prefix = "model.model."
|
||||||
|
else:
|
||||||
|
unet_key_prefix = "model.diffusion_model."
|
||||||
|
return unet_key_prefix
|
||||||
|
|
||||||
def convert_config(unet_config):
|
def convert_config(unet_config):
|
||||||
new_config = unet_config.copy()
|
new_config = unet_config.copy()
|
||||||
num_res_blocks = new_config.get("num_res_blocks", None)
|
num_res_blocks = new_config.get("num_res_blocks", None)
|
||||||
|
|||||||
@ -2,9 +2,9 @@ import psutil
|
|||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.utils
|
|
||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
|
import platform
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
@ -83,7 +83,7 @@ def get_torch_device():
|
|||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
else:
|
else:
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return torch.device("xpu")
|
return torch.device("xpu", torch.xpu.current_device())
|
||||||
else:
|
else:
|
||||||
return torch.device(torch.cuda.current_device())
|
return torch.device(torch.cuda.current_device())
|
||||||
|
|
||||||
@ -102,8 +102,8 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|||||||
elif is_intel_xpu():
|
elif is_intel_xpu():
|
||||||
stats = torch.xpu.memory_stats(dev)
|
stats = torch.xpu.memory_stats(dev)
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
|
||||||
mem_total_torch = mem_reserved
|
mem_total_torch = mem_reserved
|
||||||
|
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||||
else:
|
else:
|
||||||
stats = torch.cuda.memory_stats(dev)
|
stats = torch.cuda.memory_stats(dev)
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
@ -119,10 +119,11 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|||||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||||
if not args.normalvram and not args.cpu:
|
|
||||||
if lowvram_available and total_vram <= 4096:
|
try:
|
||||||
logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
logging.info("pytorch version: {}".format(torch.version.__version__))
|
||||||
set_vram_to = VRAMState.LOW_VRAM
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
||||||
@ -166,7 +167,7 @@ if args.use_pytorch_cross_attention:
|
|||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
XFORMERS_IS_AVAILABLE = False
|
XFORMERS_IS_AVAILABLE = False
|
||||||
|
|
||||||
VAE_DTYPE = torch.float32
|
VAE_DTYPES = [torch.float32]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
@ -175,7 +176,7 @@ try:
|
|||||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
|
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
|
||||||
VAE_DTYPE = torch.bfloat16
|
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
@ -183,17 +184,10 @@ except:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
VAE_DTYPE = torch.bfloat16
|
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
||||||
|
|
||||||
if args.cpu_vae:
|
if args.cpu_vae:
|
||||||
VAE_DTYPE = torch.float32
|
VAE_DTYPES = [torch.float32]
|
||||||
|
|
||||||
if args.fp16_vae:
|
|
||||||
VAE_DTYPE = torch.float16
|
|
||||||
elif args.bf16_vae:
|
|
||||||
VAE_DTYPE = torch.bfloat16
|
|
||||||
elif args.fp32_vae:
|
|
||||||
VAE_DTYPE = torch.float32
|
|
||||||
|
|
||||||
|
|
||||||
if ENABLE_PYTORCH_ATTENTION:
|
if ENABLE_PYTORCH_ATTENTION:
|
||||||
@ -257,7 +251,6 @@ try:
|
|||||||
except:
|
except:
|
||||||
logging.warning("Could not pick default device.")
|
logging.warning("Could not pick default device.")
|
||||||
|
|
||||||
logging.info("VAE dtype: {}".format(VAE_DTYPE))
|
|
||||||
|
|
||||||
current_loaded_models = []
|
current_loaded_models = []
|
||||||
|
|
||||||
@ -275,6 +268,7 @@ class LoadedModel:
|
|||||||
self.device = model.load_device
|
self.device = model.load_device
|
||||||
self.weights_loaded = False
|
self.weights_loaded = False
|
||||||
self.real_model = None
|
self.real_model = None
|
||||||
|
self.currently_used = True
|
||||||
|
|
||||||
def model_memory(self):
|
def model_memory(self):
|
||||||
return self.model.model_size()
|
return self.model.model_size()
|
||||||
@ -285,7 +279,7 @@ class LoadedModel:
|
|||||||
else:
|
else:
|
||||||
return self.model_memory()
|
return self.model_memory()
|
||||||
|
|
||||||
def model_load(self, lowvram_model_memory=0):
|
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
|
||||||
patch_model_to = self.device
|
patch_model_to = self.device
|
||||||
|
|
||||||
self.model.model_patches_to(self.device)
|
self.model.model_patches_to(self.device)
|
||||||
@ -295,7 +289,7 @@ class LoadedModel:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if lowvram_model_memory > 0 and load_weights:
|
if lowvram_model_memory > 0 and load_weights:
|
||||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
|
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||||
else:
|
else:
|
||||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -304,11 +298,16 @@ class LoadedModel:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||||
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
||||||
|
|
||||||
self.weights_loaded = True
|
self.weights_loaded = True
|
||||||
return self.real_model
|
return self.real_model
|
||||||
|
|
||||||
|
def should_reload_model(self, force_patch_weights=False):
|
||||||
|
if force_patch_weights and self.model.lowvram_patch_counter > 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def model_unload(self, unpatch_weights=True):
|
def model_unload(self, unpatch_weights=True):
|
||||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||||
self.model.model_patches_to(self.model.offload_device)
|
self.model.model_patches_to(self.model.offload_device)
|
||||||
@ -359,6 +358,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
if shift_model.device == device:
|
if shift_model.device == device:
|
||||||
if shift_model not in keep_loaded:
|
if shift_model not in keep_loaded:
|
||||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
|
shift_model.currently_used = False
|
||||||
|
|
||||||
for x in sorted(can_unload):
|
for x in sorted(can_unload):
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
@ -379,7 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
if mem_free_torch > mem_free_total * 0.25:
|
if mem_free_torch > mem_free_total * 0.25:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
def load_models_gpu(models, memory_required=0):
|
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||||
global vram_state
|
global vram_state
|
||||||
|
|
||||||
inference_memory = minimum_inference_memory()
|
inference_memory = minimum_inference_memory()
|
||||||
@ -391,12 +391,23 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
models_already_loaded = []
|
models_already_loaded = []
|
||||||
for x in models:
|
for x in models:
|
||||||
loaded_model = LoadedModel(x)
|
loaded_model = LoadedModel(x)
|
||||||
|
loaded = None
|
||||||
|
|
||||||
if loaded_model in current_loaded_models:
|
try:
|
||||||
index = current_loaded_models.index(loaded_model)
|
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||||
current_loaded_models.insert(0, current_loaded_models.pop(index))
|
except:
|
||||||
models_already_loaded.append(loaded_model)
|
loaded_model_index = None
|
||||||
else:
|
|
||||||
|
if loaded_model_index is not None:
|
||||||
|
loaded = current_loaded_models[loaded_model_index]
|
||||||
|
if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
|
||||||
|
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
|
||||||
|
loaded = None
|
||||||
|
else:
|
||||||
|
loaded.currently_used = True
|
||||||
|
models_already_loaded.append(loaded)
|
||||||
|
|
||||||
|
if loaded is None:
|
||||||
if hasattr(x, "model"):
|
if hasattr(x, "model"):
|
||||||
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
||||||
models_to_load.append(loaded_model)
|
models_to_load.append(loaded_model)
|
||||||
@ -436,15 +447,13 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
model_size = loaded_model.model_memory_required(torch_dev)
|
model_size = loaded_model.model_memory_required(torch_dev)
|
||||||
current_free_mem = get_free_memory(torch_dev)
|
current_free_mem = get_free_memory(torch_dev)
|
||||||
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||||
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
|
if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary
|
||||||
vram_set_state = VRAMState.LOW_VRAM
|
|
||||||
else:
|
|
||||||
lowvram_model_memory = 0
|
lowvram_model_memory = 0
|
||||||
|
|
||||||
if vram_set_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
lowvram_model_memory = 64 * 1024 * 1024
|
lowvram_model_memory = 64 * 1024 * 1024
|
||||||
|
|
||||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
|
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||||
current_loaded_models.insert(0, loaded_model)
|
current_loaded_models.insert(0, loaded_model)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -452,6 +461,16 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
def load_model_gpu(model):
|
def load_model_gpu(model):
|
||||||
return load_models_gpu([model])
|
return load_models_gpu([model])
|
||||||
|
|
||||||
|
def loaded_models(only_currently_used=False):
|
||||||
|
output = []
|
||||||
|
for m in current_loaded_models:
|
||||||
|
if only_currently_used:
|
||||||
|
if not m.currently_used:
|
||||||
|
continue
|
||||||
|
|
||||||
|
output.append(m.model)
|
||||||
|
return output
|
||||||
|
|
||||||
def cleanup_models(keep_clone_weights_loaded=False):
|
def cleanup_models(keep_clone_weights_loaded=False):
|
||||||
to_delete = []
|
to_delete = []
|
||||||
for i in range(len(current_loaded_models)):
|
for i in range(len(current_loaded_models)):
|
||||||
@ -552,8 +571,6 @@ def text_encoder_device():
|
|||||||
if args.gpu_only:
|
if args.gpu_only:
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
||||||
if is_intel_xpu():
|
|
||||||
return torch.device("cpu")
|
|
||||||
if should_use_fp16(prioritize_performance=False):
|
if should_use_fp16(prioritize_performance=False):
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
else:
|
else:
|
||||||
@ -594,9 +611,22 @@ def vae_offload_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
def vae_dtype():
|
def vae_dtype(device=None, allowed_dtypes=[]):
|
||||||
global VAE_DTYPE
|
global VAE_DTYPES
|
||||||
return VAE_DTYPE
|
if args.fp16_vae:
|
||||||
|
return torch.float16
|
||||||
|
elif args.bf16_vae:
|
||||||
|
return torch.bfloat16
|
||||||
|
elif args.fp32_vae:
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
|
for d in allowed_dtypes:
|
||||||
|
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
|
||||||
|
return d
|
||||||
|
if d in VAE_DTYPES:
|
||||||
|
return d
|
||||||
|
|
||||||
|
return VAE_DTYPES[0]
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
@ -614,11 +644,46 @@ def supports_dtype(device, dtype): #TODO
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def supports_cast(device, dtype): #TODO
|
||||||
|
if dtype == torch.float32:
|
||||||
|
return True
|
||||||
|
if dtype == torch.float16:
|
||||||
|
return True
|
||||||
|
if is_device_mps(device):
|
||||||
|
return False
|
||||||
|
if directml_enabled: #TODO: test this
|
||||||
|
return False
|
||||||
|
if dtype == torch.bfloat16:
|
||||||
|
return True
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
return True
|
||||||
|
if dtype == torch.float8_e5m2:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def device_supports_non_blocking(device):
|
def device_supports_non_blocking(device):
|
||||||
if is_device_mps(device):
|
if is_device_mps(device):
|
||||||
return False #pytorch bug? mps doesn't support non blocking
|
return False #pytorch bug? mps doesn't support non blocking
|
||||||
|
if is_intel_xpu():
|
||||||
|
return False
|
||||||
|
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
||||||
|
return False
|
||||||
|
if directml_enabled:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def device_should_use_non_blocking(device):
|
||||||
|
if not device_supports_non_blocking(device):
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
|
||||||
|
|
||||||
|
def force_channels_last():
|
||||||
|
if args.force_channels_last:
|
||||||
|
return True
|
||||||
|
|
||||||
|
#TODO
|
||||||
return False
|
return False
|
||||||
# return True #TODO: figure out why this causes issues
|
|
||||||
|
|
||||||
def cast_to_device(tensor, device, dtype, copy=False):
|
def cast_to_device(tensor, device, dtype, copy=False):
|
||||||
device_supports_cast = False
|
device_supports_cast = False
|
||||||
@ -630,7 +695,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
|||||||
elif is_intel_xpu():
|
elif is_intel_xpu():
|
||||||
device_supports_cast = True
|
device_supports_cast = True
|
||||||
|
|
||||||
non_blocking = device_supports_non_blocking(device)
|
non_blocking = device_should_use_non_blocking(device)
|
||||||
|
|
||||||
if device_supports_cast:
|
if device_supports_cast:
|
||||||
if copy:
|
if copy:
|
||||||
@ -671,8 +736,22 @@ def pytorch_attention_flash_attention():
|
|||||||
#TODO: more reliable way of checking for flash attention?
|
#TODO: more reliable way of checking for flash attention?
|
||||||
if is_nvidia(): #pytorch flash attention only works on Nvidia
|
if is_nvidia(): #pytorch flash attention only works on Nvidia
|
||||||
return True
|
return True
|
||||||
|
if is_intel_xpu():
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def force_upcast_attention_dtype():
|
||||||
|
upcast = args.force_upcast_attention
|
||||||
|
try:
|
||||||
|
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
|
||||||
|
upcast = True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if upcast:
|
||||||
|
return torch.float32
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def get_free_memory(dev=None, torch_free_too=False):
|
def get_free_memory(dev=None, torch_free_too=False):
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if dev is None:
|
if dev is None:
|
||||||
@ -688,10 +767,10 @@ def get_free_memory(dev=None, torch_free_too=False):
|
|||||||
elif is_intel_xpu():
|
elif is_intel_xpu():
|
||||||
stats = torch.xpu.memory_stats(dev)
|
stats = torch.xpu.memory_stats(dev)
|
||||||
mem_active = stats['active_bytes.all.current']
|
mem_active = stats['active_bytes.all.current']
|
||||||
mem_allocated = stats['allocated_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
mem_free_torch = mem_reserved - mem_active
|
mem_free_torch = mem_reserved - mem_active
|
||||||
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
|
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
||||||
|
mem_free_total = mem_free_xpu + mem_free_torch
|
||||||
else:
|
else:
|
||||||
stats = torch.cuda.memory_stats(dev)
|
stats = torch.cuda.memory_stats(dev)
|
||||||
mem_active = stats['active_bytes.all.current']
|
mem_active = stats['active_bytes.all.current']
|
||||||
@ -845,6 +924,7 @@ def unload_all_models():
|
|||||||
|
|
||||||
|
|
||||||
def resolve_lowvram_weight(weight, model, key): #TODO: remove
|
def resolve_lowvram_weight(weight, model, key): #TODO: remove
|
||||||
|
print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
#TODO: might be cleaner to put this somewhere else
|
||||||
|
|||||||
@ -6,17 +6,29 @@ import uuid
|
|||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.types import UnetWrapperFunction
|
||||||
|
|
||||||
def apply_weight_decompose(dora_scale, weight):
|
|
||||||
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
||||||
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
|
||||||
|
lora_diff *= alpha
|
||||||
|
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||||
weight_norm = (
|
weight_norm = (
|
||||||
weight.transpose(0, 1)
|
weight_calc.transpose(0, 1)
|
||||||
.reshape(weight.shape[1], -1)
|
.reshape(weight_calc.shape[1], -1)
|
||||||
.norm(dim=1, keepdim=True)
|
.norm(dim=1, keepdim=True)
|
||||||
.reshape(weight.shape[1], *[1] * (weight.dim() - 1))
|
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||||
.transpose(0, 1)
|
.transpose(0, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
return weight * (dora_scale / weight_norm)
|
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||||
|
if strength != 1.0:
|
||||||
|
weight_calc -= weight
|
||||||
|
weight += strength * (weight_calc)
|
||||||
|
else:
|
||||||
|
weight[:] = weight_calc
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||||
to = model_options["transformer_options"].copy()
|
to = model_options["transformer_options"].copy()
|
||||||
@ -58,14 +70,13 @@ class ModelPatcher:
|
|||||||
|
|
||||||
self.weight_inplace_update = weight_inplace_update
|
self.weight_inplace_update = weight_inplace_update
|
||||||
self.model_lowvram = False
|
self.model_lowvram = False
|
||||||
|
self.lowvram_patch_counter = 0
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
|
|
||||||
def model_size(self):
|
def model_size(self):
|
||||||
if self.size > 0:
|
if self.size > 0:
|
||||||
return self.size
|
return self.size
|
||||||
model_sd = self.model.state_dict()
|
|
||||||
self.size = comfy.model_management.module_size(self.model)
|
self.size = comfy.model_management.module_size(self.model)
|
||||||
self.model_keys = set(model_sd.keys())
|
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
@ -77,7 +88,6 @@ class ModelPatcher:
|
|||||||
|
|
||||||
n.object_patches = self.object_patches.copy()
|
n.object_patches = self.object_patches.copy()
|
||||||
n.model_options = copy.deepcopy(self.model_options)
|
n.model_options = copy.deepcopy(self.model_options)
|
||||||
n.model_keys = self.model_keys
|
|
||||||
n.backup = self.backup
|
n.backup = self.backup
|
||||||
n.object_patches_backup = self.object_patches_backup
|
n.object_patches_backup = self.object_patches_backup
|
||||||
return n
|
return n
|
||||||
@ -116,7 +126,7 @@ class ModelPatcher:
|
|||||||
if disable_cfg1_optimization:
|
if disable_cfg1_optimization:
|
||||||
self.model_options["disable_cfg1_optimization"] = True
|
self.model_options["disable_cfg1_optimization"] = True
|
||||||
|
|
||||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
|
||||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||||
|
|
||||||
def set_model_denoise_mask_function(self, denoise_mask_function):
|
def set_model_denoise_mask_function(self, denoise_mask_function):
|
||||||
@ -197,12 +207,20 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
p = set()
|
p = set()
|
||||||
|
model_sd = self.model.state_dict()
|
||||||
for k in patches:
|
for k in patches:
|
||||||
if k in self.model_keys:
|
offset = None
|
||||||
|
if isinstance(k, str):
|
||||||
|
key = k
|
||||||
|
else:
|
||||||
|
offset = k[1]
|
||||||
|
key = k[0]
|
||||||
|
|
||||||
|
if key in model_sd:
|
||||||
p.add(k)
|
p.add(k)
|
||||||
current_patches = self.patches.get(k, [])
|
current_patches = self.patches.get(key, [])
|
||||||
current_patches.append((strength_patch, patches[k], strength_model))
|
current_patches.append((strength_patch, patches[k], strength_model, offset))
|
||||||
self.patches[k] = current_patches
|
self.patches[key] = current_patches
|
||||||
|
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
return list(p)
|
return list(p)
|
||||||
@ -272,7 +290,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
|
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||||
self.patch_model(device_to, patch_weights=False)
|
self.patch_model(device_to, patch_weights=False)
|
||||||
|
|
||||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
||||||
@ -284,6 +302,7 @@ class ModelPatcher:
|
|||||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||||
|
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
|
patch_counter = 0
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
@ -296,9 +315,17 @@ class ModelPatcher:
|
|||||||
|
|
||||||
if lowvram_weight:
|
if lowvram_weight:
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
m.weight_function = LowVramPatch(weight_key, self)
|
if force_patch_weights:
|
||||||
|
self.patch_weight_to_device(weight_key)
|
||||||
|
else:
|
||||||
|
m.weight_function = LowVramPatch(weight_key, self)
|
||||||
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
m.bias_function = LowVramPatch(bias_key, self)
|
if force_patch_weights:
|
||||||
|
self.patch_weight_to_device(bias_key)
|
||||||
|
else:
|
||||||
|
m.bias_function = LowVramPatch(bias_key, self)
|
||||||
|
patch_counter += 1
|
||||||
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
@ -308,16 +335,23 @@ class ModelPatcher:
|
|||||||
self.patch_weight_to_device(bias_key, device_to)
|
self.patch_weight_to_device(bias_key, device_to)
|
||||||
m.to(device_to)
|
m.to(device_to)
|
||||||
mem_counter += comfy.model_management.module_size(m)
|
mem_counter += comfy.model_management.module_size(m)
|
||||||
logging.debug("lowvram: loaded module regularly {}".format(m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
|
|
||||||
self.model_lowvram = True
|
self.model_lowvram = True
|
||||||
|
self.lowvram_patch_counter = patch_counter
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def calculate_weight(self, patches, weight, key):
|
def calculate_weight(self, patches, weight, key):
|
||||||
for p in patches:
|
for p in patches:
|
||||||
alpha = p[0]
|
strength = p[0]
|
||||||
v = p[1]
|
v = p[1]
|
||||||
strength_model = p[2]
|
strength_model = p[2]
|
||||||
|
offset = p[3]
|
||||||
|
|
||||||
|
old_weight = None
|
||||||
|
if offset is not None:
|
||||||
|
old_weight = weight
|
||||||
|
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||||
|
|
||||||
if strength_model != 1.0:
|
if strength_model != 1.0:
|
||||||
weight *= strength_model
|
weight *= strength_model
|
||||||
@ -333,26 +367,31 @@ class ModelPatcher:
|
|||||||
|
|
||||||
if patch_type == "diff":
|
if patch_type == "diff":
|
||||||
w1 = v[0]
|
w1 = v[0]
|
||||||
if alpha != 0.0:
|
if strength != 0.0:
|
||||||
if w1.shape != weight.shape:
|
if w1.shape != weight.shape:
|
||||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||||
else:
|
else:
|
||||||
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||||
elif patch_type == "lora": #lora/locon
|
elif patch_type == "lora": #lora/locon
|
||||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||||
dora_scale = v[4]
|
dora_scale = v[4]
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
alpha *= v[2] / mat2.shape[0]
|
alpha = v[2] / mat2.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
if v[3] is not None:
|
if v[3] is not None:
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
try:
|
try:
|
||||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
|
||||||
|
else:
|
||||||
|
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "lokr":
|
elif patch_type == "lokr":
|
||||||
@ -389,19 +428,26 @@ class ModelPatcher:
|
|||||||
if len(w2.shape) == 4:
|
if len(w2.shape) == 4:
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
if v[2] is not None and dim is not None:
|
if v[2] is not None and dim is not None:
|
||||||
alpha *= v[2] / dim
|
alpha = v[2] / dim
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
|
||||||
|
else:
|
||||||
|
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "loha":
|
elif patch_type == "loha":
|
||||||
w1a = v[0]
|
w1a = v[0]
|
||||||
w1b = v[1]
|
w1b = v[1]
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
alpha *= v[2] / w1b.shape[0]
|
alpha = v[2] / w1b.shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
w2a = v[3]
|
w2a = v[3]
|
||||||
w2b = v[4]
|
w2b = v[4]
|
||||||
dora_scale = v[7]
|
dora_scale = v[7]
|
||||||
@ -424,14 +470,18 @@ class ModelPatcher:
|
|||||||
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
|
||||||
|
else:
|
||||||
|
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "glora":
|
elif patch_type == "glora":
|
||||||
if v[4] is not None:
|
if v[4] is not None:
|
||||||
alpha *= v[4] / v[0].shape[0]
|
alpha = v[4] / v[0].shape[0]
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
dora_scale = v[5]
|
dora_scale = v[5]
|
||||||
|
|
||||||
@ -441,14 +491,19 @@ class ModelPatcher:
|
|||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
|
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
|
||||||
|
else:
|
||||||
|
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
else:
|
else:
|
||||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||||
|
|
||||||
|
if old_weight is not None:
|
||||||
|
weight = old_weight
|
||||||
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||||
@ -462,6 +517,7 @@ class ModelPatcher:
|
|||||||
m.bias_function = None
|
m.bias_function = None
|
||||||
|
|
||||||
self.model_lowvram = False
|
self.model_lowvram = False
|
||||||
|
self.lowvram_patch_counter = 0
|
||||||
|
|
||||||
keys = list(self.backup.keys())
|
keys = list(self.backup.keys())
|
||||||
|
|
||||||
|
|||||||
@ -33,6 +33,19 @@ class EDM(V_PREDICTION):
|
|||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||||
|
|
||||||
|
class CONST:
|
||||||
|
def calculate_input(self, sigma, noise):
|
||||||
|
return noise
|
||||||
|
|
||||||
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
|
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||||
|
return model_input - model_output * sigma
|
||||||
|
|
||||||
|
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||||
|
return sigma * noise + (1.0 - sigma) * latent_image
|
||||||
|
|
||||||
|
def inverse_noise_scaling(self, sigma, latent):
|
||||||
|
return latent / (1.0 - sigma)
|
||||||
|
|
||||||
class ModelSamplingDiscrete(torch.nn.Module):
|
class ModelSamplingDiscrete(torch.nn.Module):
|
||||||
def __init__(self, model_config=None):
|
def __init__(self, model_config=None):
|
||||||
@ -104,6 +117,12 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
percent = 1.0 - percent
|
percent = 1.0 - percent
|
||||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
return self.sigma(torch.tensor(percent * 999.0)).item()
|
||||||
|
|
||||||
|
class ModelSamplingDiscreteEDM(ModelSamplingDiscrete):
|
||||||
|
def timestep(self, sigma):
|
||||||
|
return 0.25 * sigma.log()
|
||||||
|
|
||||||
|
def sigma(self, timestep):
|
||||||
|
return (timestep / 0.25).exp()
|
||||||
|
|
||||||
class ModelSamplingContinuousEDM(torch.nn.Module):
|
class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||||
def __init__(self, model_config=None):
|
def __init__(self, model_config=None):
|
||||||
@ -149,6 +168,56 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
|||||||
log_sigma_min = math.log(self.sigma_min)
|
log_sigma_min = math.log(self.sigma_min)
|
||||||
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSamplingContinuousV(ModelSamplingContinuousEDM):
|
||||||
|
def timestep(self, sigma):
|
||||||
|
return sigma.atan() / math.pi * 2
|
||||||
|
|
||||||
|
def sigma(self, timestep):
|
||||||
|
return (timestep * math.pi / 2).tan()
|
||||||
|
|
||||||
|
|
||||||
|
def time_snr_shift(alpha, t):
|
||||||
|
if alpha == 1.0:
|
||||||
|
return t
|
||||||
|
return alpha * t / (1 + (alpha - 1) * t)
|
||||||
|
|
||||||
|
class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||||
|
def __init__(self, model_config=None):
|
||||||
|
super().__init__()
|
||||||
|
if model_config is not None:
|
||||||
|
sampling_settings = model_config.sampling_settings
|
||||||
|
else:
|
||||||
|
sampling_settings = {}
|
||||||
|
|
||||||
|
self.set_parameters(shift=sampling_settings.get("shift", 1.0))
|
||||||
|
|
||||||
|
def set_parameters(self, shift=1.0, timesteps=1000):
|
||||||
|
self.shift = shift
|
||||||
|
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||||
|
self.register_buffer('sigmas', ts)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_min(self):
|
||||||
|
return self.sigmas[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_max(self):
|
||||||
|
return self.sigmas[-1]
|
||||||
|
|
||||||
|
def timestep(self, sigma):
|
||||||
|
return sigma * 1000
|
||||||
|
|
||||||
|
def sigma(self, timestep):
|
||||||
|
return time_snr_shift(self.shift, timestep / 1000)
|
||||||
|
|
||||||
|
def percent_to_sigma(self, percent):
|
||||||
|
if percent <= 0.0:
|
||||||
|
return 1.0
|
||||||
|
if percent >= 1.0:
|
||||||
|
return 0.0
|
||||||
|
return 1.0 - percent
|
||||||
|
|
||||||
class StableCascadeSampling(ModelSamplingDiscrete):
|
class StableCascadeSampling(ModelSamplingDiscrete):
|
||||||
def __init__(self, model_config=None):
|
def __init__(self, model_config=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
43
comfy/ops.py
43
comfy/ops.py
@ -21,7 +21,7 @@ import comfy.model_management
|
|||||||
|
|
||||||
def cast_bias_weight(s, input):
|
def cast_bias_weight(s, input):
|
||||||
bias = None
|
bias = None
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||||
if s.bias_function is not None:
|
if s.bias_function is not None:
|
||||||
@ -51,6 +51,20 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.comfy_cast_weights:
|
||||||
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
|
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
@ -133,6 +147,27 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input, output_size=None):
|
||||||
|
num_spatial_dims = 1
|
||||||
|
output_padding = self._output_padding(
|
||||||
|
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||||
|
num_spatial_dims, self.dilation)
|
||||||
|
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.conv_transpose1d(
|
||||||
|
input, weight, bias, self.stride, self.padding,
|
||||||
|
output_padding, self.groups, self.dilation)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.comfy_cast_weights:
|
||||||
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def conv_nd(s, dims, *args, **kwargs):
|
def conv_nd(s, dims, *args, **kwargs):
|
||||||
if dims == 2:
|
if dims == 2:
|
||||||
@ -147,6 +182,9 @@ class manual_cast(disable_weight_init):
|
|||||||
class Linear(disable_weight_init.Linear):
|
class Linear(disable_weight_init.Linear):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
|
class Conv1d(disable_weight_init.Conv1d):
|
||||||
|
comfy_cast_weights = True
|
||||||
|
|
||||||
class Conv2d(disable_weight_init.Conv2d):
|
class Conv2d(disable_weight_init.Conv2d):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
@ -161,3 +199,6 @@ class manual_cast(disable_weight_init):
|
|||||||
|
|
||||||
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
|
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
|
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
|
||||||
|
comfy_cast_weights = True
|
||||||
|
|||||||
22
comfy/sa_t5.py
Normal file
22
comfy/sa_t5.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
from transformers import T5TokenizerFast
|
||||||
|
import comfy.t5
|
||||||
|
import os
|
||||||
|
|
||||||
|
class T5BaseModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||||
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||||
|
|
||||||
|
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None):
|
||||||
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
|
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
|
||||||
|
|
||||||
|
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
|
||||||
|
|
||||||
|
class SAT5Model(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||||
|
super().__init__(device=device, dtype=dtype, clip_name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||||
@ -24,6 +24,12 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
|||||||
noises = torch.cat(noises, axis=0)
|
noises = torch.cat(noises, axis=0)
|
||||||
return noises
|
return noises
|
||||||
|
|
||||||
|
def fix_empty_latent_channels(model, latent_image):
|
||||||
|
latent_channels = model.get_model_object("latent_format").latent_channels #Resize the empty latent image so it has the right number of channels
|
||||||
|
if latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||||
|
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_channels, dim=1)
|
||||||
|
return latent_image
|
||||||
|
|
||||||
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||||
logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
|
logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
|
||||||
return model, positive, negative, noise_mask, []
|
return model, positive, negative, noise_mask, []
|
||||||
|
|||||||
@ -8,7 +8,8 @@ import logging
|
|||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
|
|
||||||
def get_area_and_mult(conds, x_in, timestep_in):
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
dims = tuple(x_in.shape[2:])
|
||||||
|
area = None
|
||||||
strength = 1.0
|
strength = 1.0
|
||||||
|
|
||||||
if 'timestep_start' in conds:
|
if 'timestep_start' in conds:
|
||||||
@ -20,11 +21,16 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
|||||||
if timestep_in[0] < timestep_end:
|
if timestep_in[0] < timestep_end:
|
||||||
return None
|
return None
|
||||||
if 'area' in conds:
|
if 'area' in conds:
|
||||||
area = conds['area']
|
area = list(conds['area'])
|
||||||
if 'strength' in conds:
|
if 'strength' in conds:
|
||||||
strength = conds['strength']
|
strength = conds['strength']
|
||||||
|
|
||||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
input_x = x_in
|
||||||
|
if area is not None:
|
||||||
|
for i in range(len(dims)):
|
||||||
|
area[i] = min(input_x.shape[i + 2] - area[len(dims) + i], area[i])
|
||||||
|
input_x = input_x.narrow(i + 2, area[len(dims) + i], area[i])
|
||||||
|
|
||||||
if 'mask' in conds:
|
if 'mask' in conds:
|
||||||
# Scale the mask to the size of the input
|
# Scale the mask to the size of the input
|
||||||
# The mask should have been resized as we began the sampling process
|
# The mask should have been resized as we began the sampling process
|
||||||
@ -32,28 +38,30 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
|||||||
if "mask_strength" in conds:
|
if "mask_strength" in conds:
|
||||||
mask_strength = conds["mask_strength"]
|
mask_strength = conds["mask_strength"]
|
||||||
mask = conds['mask']
|
mask = conds['mask']
|
||||||
assert(mask.shape[1] == x_in.shape[2])
|
assert(mask.shape[1:] == x_in.shape[2:])
|
||||||
assert(mask.shape[2] == x_in.shape[3])
|
|
||||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
mask = mask[:input_x.shape[0]]
|
||||||
|
if area is not None:
|
||||||
|
for i in range(len(dims)):
|
||||||
|
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
|
||||||
|
|
||||||
|
mask = mask * mask_strength
|
||||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||||
else:
|
else:
|
||||||
mask = torch.ones_like(input_x)
|
mask = torch.ones_like(input_x)
|
||||||
mult = mask * strength
|
mult = mask * strength
|
||||||
|
|
||||||
if 'mask' not in conds:
|
if 'mask' not in conds and area is not None:
|
||||||
rr = 8
|
rr = 8
|
||||||
if area[2] != 0:
|
for i in range(len(dims)):
|
||||||
for t in range(rr):
|
if area[len(dims) + i] != 0:
|
||||||
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
|
for t in range(rr):
|
||||||
if (area[0] + area[2]) < x_in.shape[2]:
|
m = mult.narrow(i + 2, t, 1)
|
||||||
for t in range(rr):
|
m *= ((1.0/rr) * (t + 1))
|
||||||
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
|
if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
|
||||||
if area[3] != 0:
|
for t in range(rr):
|
||||||
for t in range(rr):
|
m = mult.narrow(i + 2, area[i] - 1 - t, 1)
|
||||||
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
|
m *= ((1.0/rr) * (t + 1))
|
||||||
if (area[1] + area[3]) < x_in.shape[3]:
|
|
||||||
for t in range(rr):
|
|
||||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
|
||||||
|
|
||||||
conditioning = {}
|
conditioning = {}
|
||||||
model_conds = conds["model_conds"]
|
model_conds = conds["model_conds"]
|
||||||
@ -219,8 +227,19 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
|
|||||||
|
|
||||||
for o in range(batch_chunks):
|
for o in range(batch_chunks):
|
||||||
cond_index = cond_or_uncond[o]
|
cond_index = cond_or_uncond[o]
|
||||||
out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
a = area[o]
|
||||||
out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
if a is None:
|
||||||
|
out_conds[cond_index] += output[o] * mult[o]
|
||||||
|
out_counts[cond_index] += mult[o]
|
||||||
|
else:
|
||||||
|
out_c = out_conds[cond_index]
|
||||||
|
out_cts = out_counts[cond_index]
|
||||||
|
dims = len(a) // 2
|
||||||
|
for i in range(dims):
|
||||||
|
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||||
|
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||||
|
out_c += output[o] * mult[o]
|
||||||
|
out_cts += mult[o]
|
||||||
|
|
||||||
for i in range(len(out_conds)):
|
for i in range(len(out_conds)):
|
||||||
out_conds[i] /= out_counts[i]
|
out_conds[i] /= out_counts[i]
|
||||||
@ -335,7 +354,7 @@ def get_mask_aabb(masks):
|
|||||||
|
|
||||||
return bounding_boxes, is_empty
|
return bounding_boxes, is_empty
|
||||||
|
|
||||||
def resolve_areas_and_cond_masks(conditions, h, w, device):
|
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
||||||
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
|
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
|
||||||
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
||||||
for i in range(len(conditions)):
|
for i in range(len(conditions)):
|
||||||
@ -344,7 +363,14 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||||||
area = c['area']
|
area = c['area']
|
||||||
if area[0] == "percentage":
|
if area[0] == "percentage":
|
||||||
modified = c.copy()
|
modified = c.copy()
|
||||||
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
|
a = area[1:]
|
||||||
|
a_len = len(a) // 2
|
||||||
|
area = ()
|
||||||
|
for d in range(len(dims)):
|
||||||
|
area += (max(1, round(a[d] * dims[d])),)
|
||||||
|
for d in range(len(dims)):
|
||||||
|
area += (round(a[d + a_len] * dims[d]),)
|
||||||
|
|
||||||
modified['area'] = area
|
modified['area'] = area
|
||||||
c = modified
|
c = modified
|
||||||
conditions[i] = c
|
conditions[i] = c
|
||||||
@ -353,12 +379,12 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||||||
mask = c['mask']
|
mask = c['mask']
|
||||||
mask = mask.to(device=device)
|
mask = mask.to(device=device)
|
||||||
modified = c.copy()
|
modified = c.copy()
|
||||||
if len(mask.shape) == 2:
|
if len(mask.shape) == len(dims):
|
||||||
mask = mask.unsqueeze(0)
|
mask = mask.unsqueeze(0)
|
||||||
if mask.shape[1] != h or mask.shape[2] != w:
|
if mask.shape[1:] != dims:
|
||||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
|
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
|
||||||
|
|
||||||
if modified.get("set_area_to_bounds", False):
|
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
|
||||||
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||||
boxes, is_empty = get_mask_aabb(bounds)
|
boxes, is_empty = get_mask_aabb(bounds)
|
||||||
if is_empty[0]:
|
if is_empty[0]:
|
||||||
@ -375,7 +401,11 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
|
|||||||
modified['mask'] = mask
|
modified['mask'] = mask
|
||||||
conditions[i] = modified
|
conditions[i] = modified
|
||||||
|
|
||||||
def create_cond_with_same_area_if_none(conds, c):
|
def resolve_areas_and_cond_masks(conditions, h, w, device):
|
||||||
|
logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
|
||||||
|
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
|
||||||
|
|
||||||
|
def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
|
||||||
if 'area' not in c:
|
if 'area' not in c:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -479,7 +509,10 @@ def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwar
|
|||||||
params = x.copy()
|
params = x.copy()
|
||||||
params["device"] = device
|
params["device"] = device
|
||||||
params["noise"] = noise
|
params["noise"] = noise
|
||||||
params["width"] = params.get("width", noise.shape[3] * 8)
|
default_width = None
|
||||||
|
if len(noise.shape) >= 4: #TODO: 8 multiple should be set by the model
|
||||||
|
default_width = noise.shape[3] * 8
|
||||||
|
params["width"] = params.get("width", default_width)
|
||||||
params["height"] = params.get("height", noise.shape[2] * 8)
|
params["height"] = params.get("height", noise.shape[2] * 8)
|
||||||
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
params["prompt_type"] = params.get("prompt_type", prompt_type)
|
||||||
for k in kwargs:
|
for k in kwargs:
|
||||||
@ -539,6 +572,9 @@ class KSAMPLER(Sampler):
|
|||||||
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||||
if sampler_name == "dpm_fast":
|
if sampler_name == "dpm_fast":
|
||||||
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
|
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return noise
|
||||||
|
|
||||||
sigma_min = sigmas[-1]
|
sigma_min = sigmas[-1]
|
||||||
if sigma_min == 0:
|
if sigma_min == 0:
|
||||||
sigma_min = sigmas[-2]
|
sigma_min = sigmas[-2]
|
||||||
@ -547,6 +583,9 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
|||||||
sampler_function = dpm_fast_function
|
sampler_function = dpm_fast_function
|
||||||
elif sampler_name == "dpm_adaptive":
|
elif sampler_name == "dpm_adaptive":
|
||||||
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
|
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return noise
|
||||||
|
|
||||||
sigma_min = sigmas[-1]
|
sigma_min = sigmas[-1]
|
||||||
if sigma_min == 0:
|
if sigma_min == 0:
|
||||||
sigma_min = sigmas[-2]
|
sigma_min = sigmas[-2]
|
||||||
@ -561,7 +600,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
|||||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||||
for k in conds:
|
for k in conds:
|
||||||
conds[k] = conds[k][:]
|
conds[k] = conds[k][:]
|
||||||
resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device)
|
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
|
||||||
|
|
||||||
for k in conds:
|
for k in conds:
|
||||||
calculate_start_end_timesteps(model, conds[k])
|
calculate_start_end_timesteps(model, conds[k])
|
||||||
|
|||||||
176
comfy/sd.py
176
comfy/sd.py
@ -6,7 +6,7 @@ from comfy import model_management
|
|||||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||||
from .ldm.cascade.stage_a import StageA
|
from .ldm.cascade.stage_a import StageA
|
||||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||||
|
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@ -14,12 +14,13 @@ import comfy.utils
|
|||||||
from . import clip_vision
|
from . import clip_vision
|
||||||
from . import gligen
|
from . import gligen
|
||||||
from . import diffusers_convert
|
from . import diffusers_convert
|
||||||
from . import model_base
|
|
||||||
from . import model_detection
|
from . import model_detection
|
||||||
|
|
||||||
from . import sd1_clip
|
from . import sd1_clip
|
||||||
from . import sd2_clip
|
from . import sd2_clip
|
||||||
from . import sdxl_clip
|
from . import sdxl_clip
|
||||||
|
from . import sd3_clip
|
||||||
|
from . import sa_t5
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -98,13 +99,19 @@ class CLIP:
|
|||||||
load_device = model_management.text_encoder_device()
|
load_device = model_management.text_encoder_device()
|
||||||
offload_device = model_management.text_encoder_offload_device()
|
offload_device = model_management.text_encoder_offload_device()
|
||||||
params['device'] = offload_device
|
params['device'] = offload_device
|
||||||
params['dtype'] = model_management.text_encoder_dtype(load_device)
|
dtype = model_management.text_encoder_dtype(load_device)
|
||||||
|
params['dtype'] = dtype
|
||||||
|
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
|
|
||||||
|
for dt in self.cond_stage_model.dtypes:
|
||||||
|
if not model_management.supports_cast(load_device, dt):
|
||||||
|
load_device = offload_device
|
||||||
|
|
||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
|
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
@ -168,8 +175,10 @@ class VAE:
|
|||||||
self.downscale_ratio = 8
|
self.downscale_ratio = 8
|
||||||
self.upscale_ratio = 8
|
self.upscale_ratio = 8
|
||||||
self.latent_channels = 4
|
self.latent_channels = 4
|
||||||
|
self.output_channels = 3
|
||||||
self.process_input = lambda image: image * 2.0 - 1.0
|
self.process_input = lambda image: image * 2.0 - 1.0
|
||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
if "decoder.mid.block_1.mix_factor" in sd:
|
if "decoder.mid.block_1.mix_factor" in sd:
|
||||||
@ -181,7 +190,8 @@ class VAE:
|
|||||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||||
elif "taesd_decoder.1.weight" in sd:
|
elif "taesd_decoder.1.weight" in sd:
|
||||||
self.first_stage_model = comfy.taesd.taesd.TAESD()
|
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
||||||
|
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
||||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||||
self.first_stage_model = StageA()
|
self.first_stage_model = StageA()
|
||||||
self.downscale_ratio = 4
|
self.downscale_ratio = 4
|
||||||
@ -210,7 +220,7 @@ class VAE:
|
|||||||
self.first_stage_model = StageC_coder()
|
self.first_stage_model = StageC_coder()
|
||||||
self.downscale_ratio = 32
|
self.downscale_ratio = 32
|
||||||
self.latent_channels = 16
|
self.latent_channels = 16
|
||||||
else:
|
elif "decoder.conv_in.weight" in sd:
|
||||||
#default SD1.x/SD2.x VAE parameters
|
#default SD1.x/SD2.x VAE parameters
|
||||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||||
|
|
||||||
@ -226,6 +236,21 @@ class VAE:
|
|||||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||||
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
|
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
|
||||||
|
elif "decoder.layers.0.weight_v" in sd:
|
||||||
|
self.first_stage_model = AudioOobleckVAE()
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
|
||||||
|
self.latent_channels = 64
|
||||||
|
self.output_channels = 2
|
||||||
|
self.upscale_ratio = 2048
|
||||||
|
self.downscale_ratio = 2048
|
||||||
|
self.process_output = lambda audio: audio
|
||||||
|
self.process_input = lambda audio: audio
|
||||||
|
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
|
else:
|
||||||
|
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||||
|
self.first_stage_model = None
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||||
self.first_stage_model = self.first_stage_model.eval()
|
self.first_stage_model = self.first_stage_model.eval()
|
||||||
@ -242,20 +267,21 @@ class VAE:
|
|||||||
self.device = device
|
self.device = device
|
||||||
offload_device = model_management.vae_offload_device()
|
offload_device = model_management.vae_offload_device()
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = model_management.vae_dtype()
|
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
|
||||||
self.vae_dtype = dtype
|
self.vae_dtype = dtype
|
||||||
self.first_stage_model.to(self.vae_dtype)
|
self.first_stage_model.to(self.vae_dtype)
|
||||||
self.output_device = model_management.intermediate_device()
|
self.output_device = model_management.intermediate_device()
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||||
|
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||||
|
|
||||||
def vae_encode_crop_pixels(self, pixels):
|
def vae_encode_crop_pixels(self, pixels):
|
||||||
x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
|
dims = pixels.shape[1:-1]
|
||||||
y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
|
for d in range(len(dims)):
|
||||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
|
||||||
x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
|
x_offset = (dims[d] % self.downscale_ratio) // 2
|
||||||
y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
|
if x != dims[d]:
|
||||||
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
|
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||||
return pixels
|
return pixels
|
||||||
|
|
||||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||||
@ -293,7 +319,7 @@ class VAE:
|
|||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / memory_used)
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
|
|
||||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
|
pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
|
||||||
for x in range(0, samples_in.shape[0], batch_number):
|
for x in range(0, samples_in.shape[0], batch_number):
|
||||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||||
@ -318,7 +344,7 @@ class VAE:
|
|||||||
free_memory = model_management.get_free_memory(self.device)
|
free_memory = model_management.get_free_memory(self.device)
|
||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / memory_used)
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
|
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
||||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||||
@ -360,6 +386,8 @@ def load_style_model(ckpt_path):
|
|||||||
class CLIPType(Enum):
|
class CLIPType(Enum):
|
||||||
STABLE_DIFFUSION = 1
|
STABLE_DIFFUSION = 1
|
||||||
STABLE_CASCADE = 2
|
STABLE_CASCADE = 2
|
||||||
|
SD3 = 3
|
||||||
|
STABLE_AUDIO = 4
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||||
clip_data = []
|
clip_data = []
|
||||||
@ -389,12 +417,26 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||||
clip_target.clip = sd2_clip.SD2ClipModel
|
clip_target.clip = sd2_clip.SD2ClipModel
|
||||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||||
|
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
||||||
|
dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
|
||||||
|
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||||
|
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||||
|
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
||||||
|
clip_target.clip = sa_t5.SAT5Model
|
||||||
|
clip_target.tokenizer = sa_t5.SAT5Tokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
clip_target.clip = sd1_clip.SD1ClipModel
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||||
else:
|
elif len(clip_data) == 2:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
if clip_type == CLIPType.SD3:
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||||
|
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||||
|
else:
|
||||||
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
|
elif len(clip_data) == 3:
|
||||||
|
clip_target.clip = sd3_clip.SD3ClipModel
|
||||||
|
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
@ -414,6 +456,8 @@ def load_gligen(ckpt_path):
|
|||||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||||
|
|
||||||
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
||||||
|
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
|
||||||
|
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
|
||||||
#TODO: this function is a mess and should be removed eventually
|
#TODO: this function is a mess and should be removed eventually
|
||||||
if config is None:
|
if config is None:
|
||||||
with open(config_path, 'r') as stream:
|
with open(config_path, 'r') as stream:
|
||||||
@ -421,81 +465,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
model_config_params = config['model']['params']
|
model_config_params = config['model']['params']
|
||||||
clip_config = model_config_params['cond_stage_config']
|
clip_config = model_config_params['cond_stage_config']
|
||||||
scale_factor = model_config_params['scale_factor']
|
scale_factor = model_config_params['scale_factor']
|
||||||
vae_config = model_config_params['first_stage_config']
|
|
||||||
|
|
||||||
fp16 = False
|
|
||||||
if "unet_config" in model_config_params:
|
|
||||||
if "params" in model_config_params["unet_config"]:
|
|
||||||
unet_config = model_config_params["unet_config"]["params"]
|
|
||||||
if "use_fp16" in unet_config:
|
|
||||||
fp16 = unet_config.pop("use_fp16")
|
|
||||||
if fp16:
|
|
||||||
unet_config["dtype"] = torch.float16
|
|
||||||
|
|
||||||
noise_aug_config = None
|
|
||||||
if "noise_aug_config" in model_config_params:
|
|
||||||
noise_aug_config = model_config_params["noise_aug_config"]
|
|
||||||
|
|
||||||
model_type = model_base.ModelType.EPS
|
|
||||||
|
|
||||||
if "parameterization" in model_config_params:
|
if "parameterization" in model_config_params:
|
||||||
if model_config_params["parameterization"] == "v":
|
if model_config_params["parameterization"] == "v":
|
||||||
model_type = model_base.ModelType.V_PREDICTION
|
m = model.clone()
|
||||||
|
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.V_PREDICTION):
|
||||||
|
pass
|
||||||
|
m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
|
||||||
|
model = m
|
||||||
|
|
||||||
clip = None
|
layer_idx = clip_config.get("params", {}).get("layer_idx", None)
|
||||||
vae = None
|
if layer_idx is not None:
|
||||||
|
clip.clip_layer(layer_idx)
|
||||||
|
|
||||||
class WeightsLoader(torch.nn.Module):
|
return (model, clip, vae)
|
||||||
pass
|
|
||||||
|
|
||||||
if state_dict is None:
|
|
||||||
state_dict = comfy.utils.load_torch_file(ckpt_path)
|
|
||||||
|
|
||||||
class EmptyClass:
|
|
||||||
pass
|
|
||||||
|
|
||||||
model_config = comfy.supported_models_base.BASE({})
|
|
||||||
|
|
||||||
from . import latent_formats
|
|
||||||
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
|
||||||
model_config.unet_config = model_detection.convert_config(unet_config)
|
|
||||||
|
|
||||||
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
|
||||||
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
|
|
||||||
else:
|
|
||||||
model = model_base.BaseModel(model_config, model_type=model_type)
|
|
||||||
|
|
||||||
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
|
||||||
model.set_inpaint()
|
|
||||||
|
|
||||||
if fp16:
|
|
||||||
model = model.half()
|
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
|
||||||
model = model.to(offload_device)
|
|
||||||
model.load_model_weights(state_dict, "model.diffusion_model.")
|
|
||||||
|
|
||||||
if output_vae:
|
|
||||||
vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
|
|
||||||
vae = VAE(sd=vae_sd, config=vae_config)
|
|
||||||
|
|
||||||
if output_clip:
|
|
||||||
w = WeightsLoader()
|
|
||||||
clip_target = EmptyClass()
|
|
||||||
clip_target.params = clip_config.get("params", {})
|
|
||||||
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
|
||||||
clip_target.clip = sd2_clip.SD2ClipModel
|
|
||||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
|
||||||
w.cond_stage_model = clip.cond_stage_model.clip_h
|
|
||||||
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
|
|
||||||
clip_target.clip = sd1_clip.SD1ClipModel
|
|
||||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
|
||||||
w.cond_stage_model = clip.cond_stage_model.clip_l
|
|
||||||
load_clip_weights(w, state_dict)
|
|
||||||
|
|
||||||
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
|
||||||
|
|
||||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
sd = comfy.utils.load_torch_file(ckpt_path)
|
||||||
@ -507,10 +490,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
model_patcher = None
|
model_patcher = None
|
||||||
clip_target = None
|
clip_target = None
|
||||||
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
|
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_management.get_torch_device()
|
||||||
|
|
||||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
|
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
@ -525,8 +509,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
if output_model:
|
if output_model:
|
||||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||||
model.load_model_weights(sd, "model.diffusion_model.")
|
model.load_model_weights(sd, diffusion_model_prefix)
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||||
@ -534,14 +518,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
vae = VAE(sd=vae_sd)
|
vae = VAE(sd=vae_sd)
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
clip_target = model_config.clip_target()
|
clip_target = model_config.clip_target(state_dict=sd)
|
||||||
if clip_target is not None:
|
if clip_target is not None:
|
||||||
clip_sd = model_config.process_clip_state_dict(sd)
|
clip_sd = model_config.process_clip_state_dict(sd)
|
||||||
if len(clip_sd) > 0:
|
if len(clip_sd) > 0:
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
logging.warning("clip missing: {}".format(m))
|
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||||
|
if len(m_filter) > 0:
|
||||||
|
logging.warning("clip missing: {}".format(m))
|
||||||
|
else:
|
||||||
|
logging.debug("clip missing: {}".format(m))
|
||||||
|
|
||||||
if len(u) > 0:
|
if len(u) > 0:
|
||||||
logging.debug("clip unexpected {}:".format(u))
|
logging.debug("clip unexpected {}:".format(u))
|
||||||
@ -613,7 +601,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
|||||||
load_models.append(clip.load_model())
|
load_models.append(clip.load_model())
|
||||||
clip_sd = clip.get_sd()
|
clip_sd = clip.get_sd()
|
||||||
|
|
||||||
model_management.load_models_gpu(load_models)
|
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||||
for k in extra_keys:
|
for k in extra_keys:
|
||||||
|
|||||||
@ -68,7 +68,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
]
|
]
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||||
|
return_projected_pooled=True): # clip-vit-base-patch32
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert layer in self.LAYERS
|
assert layer in self.LAYERS
|
||||||
|
|
||||||
@ -90,6 +91,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
|
|
||||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||||
self.enable_attention_masks = enable_attention_masks
|
self.enable_attention_masks = enable_attention_masks
|
||||||
|
self.zero_out_masked = zero_out_masked
|
||||||
|
|
||||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||||
self.return_projected_pooled = return_projected_pooled
|
self.return_projected_pooled = return_projected_pooled
|
||||||
@ -168,20 +170,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
attention_mask = None
|
attention_mask = None
|
||||||
if self.enable_attention_masks:
|
if self.enable_attention_masks:
|
||||||
attention_mask = torch.zeros_like(tokens)
|
attention_mask = torch.zeros_like(tokens)
|
||||||
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
|
end_token = self.special_tokens.get("end", -1)
|
||||||
for x in range(attention_mask.shape[0]):
|
for x in range(attention_mask.shape[0]):
|
||||||
for y in range(attention_mask.shape[1]):
|
for y in range(attention_mask.shape[1]):
|
||||||
attention_mask[x, y] = 1
|
attention_mask[x, y] = 1
|
||||||
if tokens[x, y] == max_token:
|
if tokens[x, y] == end_token:
|
||||||
break
|
break
|
||||||
|
|
||||||
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||||
self.transformer.set_input_embeddings(backup_embeds)
|
self.transformer.set_input_embeddings(backup_embeds)
|
||||||
|
|
||||||
if self.layer == "last":
|
if self.layer == "last":
|
||||||
z = outputs[0]
|
z = outputs[0].float()
|
||||||
else:
|
else:
|
||||||
z = outputs[1]
|
z = outputs[1].float()
|
||||||
|
|
||||||
|
if self.zero_out_masked and attention_mask is not None:
|
||||||
|
z *= attention_mask.unsqueeze(-1).float()
|
||||||
|
|
||||||
pooled_output = None
|
pooled_output = None
|
||||||
if len(outputs) >= 3:
|
if len(outputs) >= 3:
|
||||||
@ -190,7 +195,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
elif outputs[2] is not None:
|
elif outputs[2] is not None:
|
||||||
pooled_output = outputs[2].float()
|
pooled_output = outputs[2].float()
|
||||||
|
|
||||||
return z.float(), pooled_output
|
return z, pooled_output
|
||||||
|
|
||||||
def encode(self, tokens):
|
def encode(self, tokens):
|
||||||
return self(tokens)
|
return self(tokens)
|
||||||
@ -506,6 +511,10 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||||
|
|
||||||
|
self.dtypes = set()
|
||||||
|
if dtype is not None:
|
||||||
|
self.dtypes.add(dtype)
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
getattr(self, self.clip).set_clip_options(options)
|
getattr(self, self.clip).set_clip_options(options)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from comfy import sd1_clip
|
from comfy import sd1_clip
|
||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||||
|
|||||||
150
comfy/sd3_clip.py
Normal file
150
comfy/sd3_clip.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
from comfy import sdxl_clip
|
||||||
|
from transformers import T5TokenizerFast
|
||||||
|
import comfy.t5
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import comfy.model_management
|
||||||
|
import logging
|
||||||
|
|
||||||
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||||
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
|
||||||
|
|
||||||
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None):
|
||||||
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
|
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||||
|
|
||||||
|
class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||||
|
|
||||||
|
class SDT5XXLModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||||
|
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SD3Tokenizer:
|
||||||
|
def __init__(self, embedding_directory=None):
|
||||||
|
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||||
|
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||||
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
|
out = {}
|
||||||
|
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||||
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||||
|
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def untokenize(self, token_weight_pair):
|
||||||
|
return self.clip_g.untokenize(token_weight_pair)
|
||||||
|
|
||||||
|
class SD3ClipModel(torch.nn.Module):
|
||||||
|
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
self.dtypes = set()
|
||||||
|
if clip_l:
|
||||||
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
||||||
|
self.dtypes.add(dtype)
|
||||||
|
else:
|
||||||
|
self.clip_l = None
|
||||||
|
|
||||||
|
if clip_g:
|
||||||
|
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
||||||
|
self.dtypes.add(dtype)
|
||||||
|
else:
|
||||||
|
self.clip_g = None
|
||||||
|
|
||||||
|
if t5:
|
||||||
|
if dtype_t5 is None:
|
||||||
|
dtype_t5 = dtype
|
||||||
|
elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
|
||||||
|
dtype_t5 = dtype
|
||||||
|
|
||||||
|
if not comfy.model_management.supports_cast(device, dtype_t5):
|
||||||
|
dtype_t5 = dtype
|
||||||
|
|
||||||
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||||
|
self.dtypes.add(dtype_t5)
|
||||||
|
else:
|
||||||
|
self.t5xxl = None
|
||||||
|
|
||||||
|
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
|
||||||
|
|
||||||
|
def set_clip_options(self, options):
|
||||||
|
if self.clip_l is not None:
|
||||||
|
self.clip_l.set_clip_options(options)
|
||||||
|
if self.clip_g is not None:
|
||||||
|
self.clip_g.set_clip_options(options)
|
||||||
|
if self.t5xxl is not None:
|
||||||
|
self.t5xxl.set_clip_options(options)
|
||||||
|
|
||||||
|
def reset_clip_options(self):
|
||||||
|
if self.clip_l is not None:
|
||||||
|
self.clip_l.reset_clip_options()
|
||||||
|
if self.clip_g is not None:
|
||||||
|
self.clip_g.reset_clip_options()
|
||||||
|
if self.t5xxl is not None:
|
||||||
|
self.t5xxl.reset_clip_options()
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
token_weight_pairs_l = token_weight_pairs["l"]
|
||||||
|
token_weight_pairs_g = token_weight_pairs["g"]
|
||||||
|
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
|
||||||
|
lg_out = None
|
||||||
|
pooled = None
|
||||||
|
out = None
|
||||||
|
|
||||||
|
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||||
|
if self.clip_l is not None:
|
||||||
|
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
|
else:
|
||||||
|
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
if self.clip_g is not None:
|
||||||
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||||
|
if lg_out is not None:
|
||||||
|
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
||||||
|
else:
|
||||||
|
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||||
|
else:
|
||||||
|
g_out = None
|
||||||
|
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
if lg_out is not None:
|
||||||
|
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||||
|
out = lg_out
|
||||||
|
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||||
|
|
||||||
|
if self.t5xxl is not None:
|
||||||
|
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
|
||||||
|
if lg_out is not None:
|
||||||
|
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||||
|
else:
|
||||||
|
out = t5_out
|
||||||
|
|
||||||
|
if out is None:
|
||||||
|
out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
if pooled is None:
|
||||||
|
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
return out, pooled
|
||||||
|
|
||||||
|
def load_sd(self, sd):
|
||||||
|
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||||
|
return self.clip_g.load_sd(sd)
|
||||||
|
elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||||
|
return self.clip_l.load_sd(sd)
|
||||||
|
else:
|
||||||
|
return self.t5xxl.load_sd(sd)
|
||||||
|
|
||||||
|
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
||||||
|
class SD3ClipModel_(SD3ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None):
|
||||||
|
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||||
|
return SD3ClipModel_
|
||||||
@ -39,6 +39,7 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||||
|
self.dtypes = set([dtype])
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
self.clip_l.set_clip_options(options)
|
self.clip_l.set_clip_options(options)
|
||||||
|
|||||||
@ -5,6 +5,8 @@ from . import utils
|
|||||||
from . import sd1_clip
|
from . import sd1_clip
|
||||||
from . import sd2_clip
|
from . import sd2_clip
|
||||||
from . import sdxl_clip
|
from . import sdxl_clip
|
||||||
|
from . import sd3_clip
|
||||||
|
from . import sa_t5
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@ -53,7 +55,7 @@ class SD15(supported_models_base.BASE):
|
|||||||
replace_prefix = {"clip_l.": "cond_stage_model."}
|
replace_prefix = {"clip_l.": "cond_stage_model."}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
||||||
|
|
||||||
class SD20(supported_models_base.BASE):
|
class SD20(supported_models_base.BASE):
|
||||||
@ -65,6 +67,12 @@ class SD20(supported_models_base.BASE):
|
|||||||
"use_temporal_attention": False,
|
"use_temporal_attention": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {
|
||||||
|
"num_heads": -1,
|
||||||
|
"num_head_channels": 64,
|
||||||
|
"attn_precision": torch.float32,
|
||||||
|
}
|
||||||
|
|
||||||
latent_format = latent_formats.SD15
|
latent_format = latent_formats.SD15
|
||||||
|
|
||||||
def model_type(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
@ -90,7 +98,7 @@ class SD20(supported_models_base.BASE):
|
|||||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
|
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
|
||||||
|
|
||||||
class SD21UnclipL(SD20):
|
class SD21UnclipL(SD20):
|
||||||
@ -152,7 +160,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
|||||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||||
return state_dict_g
|
return state_dict_g
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
|
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
|
||||||
|
|
||||||
class SDXL(supported_models_base.BASE):
|
class SDXL(supported_models_base.BASE):
|
||||||
@ -221,7 +229,7 @@ class SDXL(supported_models_base.BASE):
|
|||||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||||
return state_dict_g
|
return state_dict_g
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
||||||
|
|
||||||
class SSD1B(SDXL):
|
class SSD1B(SDXL):
|
||||||
@ -276,6 +284,12 @@ class SVD_img2vid(supported_models_base.BASE):
|
|||||||
"use_temporal_resblock": True
|
"use_temporal_resblock": True
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {
|
||||||
|
"num_heads": -1,
|
||||||
|
"num_head_channels": 64,
|
||||||
|
"attn_precision": torch.float32,
|
||||||
|
}
|
||||||
|
|
||||||
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
|
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
|
||||||
|
|
||||||
latent_format = latent_formats.SD15
|
latent_format = latent_formats.SD15
|
||||||
@ -286,7 +300,7 @@ class SVD_img2vid(supported_models_base.BASE):
|
|||||||
out = model_base.SVD_img2vid(self, device=device)
|
out = model_base.SVD_img2vid(self, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class SV3D_u(SVD_img2vid):
|
class SV3D_u(SVD_img2vid):
|
||||||
@ -352,7 +366,7 @@ class Stable_Zero123(supported_models_base.BASE):
|
|||||||
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
|
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class SD_X4Upscaler(SD20):
|
class SD_X4Upscaler(SD20):
|
||||||
@ -426,7 +440,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
|||||||
out = model_base.StableCascade_C(self, device=device)
|
out = model_base.StableCascade_C(self, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def clip_target(self):
|
def clip_target(self, state_dict={}):
|
||||||
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
||||||
|
|
||||||
class Stable_Cascade_B(Stable_Cascade_C):
|
class Stable_Cascade_B(Stable_Cascade_C):
|
||||||
@ -476,6 +490,70 @@ class SDXL_instructpix2pix(SDXL):
|
|||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
|
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
class SD3(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"in_channels": 16,
|
||||||
|
"pos_embed_scaling_factor": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 3.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.SD3
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.SD3(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
clip_l = False
|
||||||
|
clip_g = False
|
||||||
|
t5 = False
|
||||||
|
dtype_t5 = None
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
|
clip_l = True
|
||||||
|
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||||
|
clip_g = True
|
||||||
|
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||||
|
if t5_key in state_dict:
|
||||||
|
t5 = True
|
||||||
|
dtype_t5 = state_dict[t5_key].dtype
|
||||||
|
|
||||||
|
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||||
|
|
||||||
|
class StableAudio(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"audio_model": "dit1.0",
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.StableAudio1
|
||||||
|
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
vae_key_prefix = ["pretransform.model."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
seconds_start_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_start.": ""}, filter_keys=True)
|
||||||
|
seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
|
||||||
|
return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def process_unet_state_dict(self, state_dict):
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
|
||||||
|
state_dict.pop(k)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
|
||||||
|
|
||||||
|
|
||||||
|
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
231
comfy/t5.py
Normal file
231
comfy/t5.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
|
|
||||||
|
class T5LayerNorm(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = torch.nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
variance = x.pow(2).mean(-1, keepdim=True)
|
||||||
|
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||||
|
|
||||||
|
class T5DenseActDense(torch.nn.Module):
|
||||||
|
def __init__(self, model_dim, ff_dim, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.nn.functional.relu(self.wi(x))
|
||||||
|
# x = self.dropout(x)
|
||||||
|
x = self.wo(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class T5DenseGatedActDense(torch.nn.Module):
|
||||||
|
def __init__(self, model_dim, ff_dim, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||||
|
hidden_linear = self.wi_1(x)
|
||||||
|
x = hidden_gelu * hidden_linear
|
||||||
|
# x = self.dropout(x)
|
||||||
|
x = self.wo(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class T5LayerFF(torch.nn.Module):
|
||||||
|
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
if ff_activation == "gelu_pytorch_tanh":
|
||||||
|
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations)
|
||||||
|
elif ff_activation == "relu":
|
||||||
|
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations)
|
||||||
|
|
||||||
|
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
forwarded_states = self.layer_norm(x)
|
||||||
|
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||||
|
# x = x + self.dropout(forwarded_states)
|
||||||
|
x += forwarded_states
|
||||||
|
return x
|
||||||
|
|
||||||
|
class T5Attention(torch.nn.Module):
|
||||||
|
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||||
|
self.q = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.k = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.v = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.o = operations.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.relative_attention_bias = None
|
||||||
|
if relative_attention_bias:
|
||||||
|
self.relative_attention_num_buckets = 32
|
||||||
|
self.relative_attention_max_distance = 128
|
||||||
|
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
||||||
|
"""
|
||||||
|
Adapted from Mesh Tensorflow:
|
||||||
|
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||||
|
|
||||||
|
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||||
|
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||||
|
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||||
|
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||||
|
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||||
|
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relative_position: an int32 Tensor
|
||||||
|
bidirectional: a boolean - whether the attention is bidirectional
|
||||||
|
num_buckets: an integer
|
||||||
|
max_distance: an integer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||||
|
"""
|
||||||
|
relative_buckets = 0
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets //= 2
|
||||||
|
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
||||||
|
relative_position = torch.abs(relative_position)
|
||||||
|
else:
|
||||||
|
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
||||||
|
# now relative_position is in the range [0, inf)
|
||||||
|
|
||||||
|
# half of the buckets are for exact increments in positions
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = relative_position < max_exact
|
||||||
|
|
||||||
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||||
|
relative_position_if_large = max_exact + (
|
||||||
|
torch.log(relative_position.float() / max_exact)
|
||||||
|
/ math.log(max_distance / max_exact)
|
||||||
|
* (num_buckets - max_exact)
|
||||||
|
).to(torch.long)
|
||||||
|
relative_position_if_large = torch.min(
|
||||||
|
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
||||||
|
return relative_buckets
|
||||||
|
|
||||||
|
def compute_bias(self, query_length, key_length, device):
|
||||||
|
"""Compute binned relative position bias"""
|
||||||
|
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
||||||
|
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
||||||
|
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||||
|
relative_position_bucket = self._relative_position_bucket(
|
||||||
|
relative_position, # shape (query_length, key_length)
|
||||||
|
bidirectional=True,
|
||||||
|
num_buckets=self.relative_attention_num_buckets,
|
||||||
|
max_distance=self.relative_attention_max_distance,
|
||||||
|
)
|
||||||
|
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||||
|
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||||
|
q = self.q(x)
|
||||||
|
k = self.k(x)
|
||||||
|
v = self.v(x)
|
||||||
|
if self.relative_attention_bias is not None:
|
||||||
|
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||||
|
|
||||||
|
if past_bias is not None:
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask + past_bias
|
||||||
|
else:
|
||||||
|
mask = past_bias
|
||||||
|
|
||||||
|
out = optimized_attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
|
||||||
|
return self.o(out), past_bias
|
||||||
|
|
||||||
|
class T5LayerSelfAttention(torch.nn.Module):
|
||||||
|
def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations)
|
||||||
|
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||||
|
normed_hidden_states = self.layer_norm(x)
|
||||||
|
output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
|
||||||
|
# x = x + self.dropout(attention_output)
|
||||||
|
x += output
|
||||||
|
return x, past_bias
|
||||||
|
|
||||||
|
class T5Block(torch.nn.Module):
|
||||||
|
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.layer = torch.nn.ModuleList()
|
||||||
|
self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
|
||||||
|
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations))
|
||||||
|
|
||||||
|
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||||
|
x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
|
||||||
|
x = self.layer[-1](x)
|
||||||
|
return x, past_bias
|
||||||
|
|
||||||
|
class T5Stack(torch.nn.Module):
|
||||||
|
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.block = torch.nn.ModuleList(
|
||||||
|
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
|
||||||
|
)
|
||||||
|
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||||
|
|
||||||
|
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||||
|
mask = None
|
||||||
|
if attention_mask is not None:
|
||||||
|
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||||
|
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||||
|
|
||||||
|
intermediate = None
|
||||||
|
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
|
||||||
|
past_bias = None
|
||||||
|
for i, l in enumerate(self.block):
|
||||||
|
x, past_bias = l(x, mask, past_bias, optimized_attention)
|
||||||
|
if i == intermediate_output:
|
||||||
|
intermediate = x.clone()
|
||||||
|
x = self.final_layer_norm(x)
|
||||||
|
if intermediate is not None and final_layer_norm_intermediate:
|
||||||
|
intermediate = self.final_layer_norm(intermediate)
|
||||||
|
return x, intermediate
|
||||||
|
|
||||||
|
class T5(torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.num_layers = config_dict["num_layers"]
|
||||||
|
model_dim = config_dict["d_model"]
|
||||||
|
|
||||||
|
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.shared
|
||||||
|
|
||||||
|
def set_input_embeddings(self, embeddings):
|
||||||
|
self.shared = embeddings
|
||||||
|
|
||||||
|
def forward(self, input_ids, *args, **kwargs):
|
||||||
|
x = self.shared(input_ids)
|
||||||
|
return self.encoder(x, *args, **kwargs)
|
||||||
21
comfy/t5_config_base.json
Normal file
21
comfy/t5_config_base.json
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"d_ff": 3072,
|
||||||
|
"d_kv": 64,
|
||||||
|
"d_model": 768,
|
||||||
|
"decoder_start_token_id": 0,
|
||||||
|
"dropout_rate": 0.1,
|
||||||
|
"eos_token_id": 1,
|
||||||
|
"dense_act_fn": "relu",
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"is_encoder_decoder": true,
|
||||||
|
"layer_norm_epsilon": 1e-06,
|
||||||
|
"model_type": "t5",
|
||||||
|
"num_decoder_layers": 12,
|
||||||
|
"num_heads": 12,
|
||||||
|
"num_layers": 12,
|
||||||
|
"output_past": true,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
"relative_attention_num_buckets": 32,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"vocab_size": 32128
|
||||||
|
}
|
||||||
21
comfy/t5_config_xxl.json
Normal file
21
comfy/t5_config_xxl.json
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"d_ff": 10240,
|
||||||
|
"d_kv": 64,
|
||||||
|
"d_model": 4096,
|
||||||
|
"decoder_start_token_id": 0,
|
||||||
|
"dropout_rate": 0.1,
|
||||||
|
"eos_token_id": 1,
|
||||||
|
"dense_act_fn": "gelu_pytorch_tanh",
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"is_encoder_decoder": true,
|
||||||
|
"layer_norm_epsilon": 1e-06,
|
||||||
|
"model_type": "t5",
|
||||||
|
"num_decoder_layers": 24,
|
||||||
|
"num_heads": 64,
|
||||||
|
"num_layers": 24,
|
||||||
|
"output_past": true,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
"relative_attention_num_buckets": 32,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"vocab_size": 32128
|
||||||
|
}
|
||||||
125
comfy/t5_tokenizer/special_tokens_map.json
Normal file
125
comfy/t5_tokenizer/special_tokens_map.json
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
{
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<extra_id_0>",
|
||||||
|
"<extra_id_1>",
|
||||||
|
"<extra_id_2>",
|
||||||
|
"<extra_id_3>",
|
||||||
|
"<extra_id_4>",
|
||||||
|
"<extra_id_5>",
|
||||||
|
"<extra_id_6>",
|
||||||
|
"<extra_id_7>",
|
||||||
|
"<extra_id_8>",
|
||||||
|
"<extra_id_9>",
|
||||||
|
"<extra_id_10>",
|
||||||
|
"<extra_id_11>",
|
||||||
|
"<extra_id_12>",
|
||||||
|
"<extra_id_13>",
|
||||||
|
"<extra_id_14>",
|
||||||
|
"<extra_id_15>",
|
||||||
|
"<extra_id_16>",
|
||||||
|
"<extra_id_17>",
|
||||||
|
"<extra_id_18>",
|
||||||
|
"<extra_id_19>",
|
||||||
|
"<extra_id_20>",
|
||||||
|
"<extra_id_21>",
|
||||||
|
"<extra_id_22>",
|
||||||
|
"<extra_id_23>",
|
||||||
|
"<extra_id_24>",
|
||||||
|
"<extra_id_25>",
|
||||||
|
"<extra_id_26>",
|
||||||
|
"<extra_id_27>",
|
||||||
|
"<extra_id_28>",
|
||||||
|
"<extra_id_29>",
|
||||||
|
"<extra_id_30>",
|
||||||
|
"<extra_id_31>",
|
||||||
|
"<extra_id_32>",
|
||||||
|
"<extra_id_33>",
|
||||||
|
"<extra_id_34>",
|
||||||
|
"<extra_id_35>",
|
||||||
|
"<extra_id_36>",
|
||||||
|
"<extra_id_37>",
|
||||||
|
"<extra_id_38>",
|
||||||
|
"<extra_id_39>",
|
||||||
|
"<extra_id_40>",
|
||||||
|
"<extra_id_41>",
|
||||||
|
"<extra_id_42>",
|
||||||
|
"<extra_id_43>",
|
||||||
|
"<extra_id_44>",
|
||||||
|
"<extra_id_45>",
|
||||||
|
"<extra_id_46>",
|
||||||
|
"<extra_id_47>",
|
||||||
|
"<extra_id_48>",
|
||||||
|
"<extra_id_49>",
|
||||||
|
"<extra_id_50>",
|
||||||
|
"<extra_id_51>",
|
||||||
|
"<extra_id_52>",
|
||||||
|
"<extra_id_53>",
|
||||||
|
"<extra_id_54>",
|
||||||
|
"<extra_id_55>",
|
||||||
|
"<extra_id_56>",
|
||||||
|
"<extra_id_57>",
|
||||||
|
"<extra_id_58>",
|
||||||
|
"<extra_id_59>",
|
||||||
|
"<extra_id_60>",
|
||||||
|
"<extra_id_61>",
|
||||||
|
"<extra_id_62>",
|
||||||
|
"<extra_id_63>",
|
||||||
|
"<extra_id_64>",
|
||||||
|
"<extra_id_65>",
|
||||||
|
"<extra_id_66>",
|
||||||
|
"<extra_id_67>",
|
||||||
|
"<extra_id_68>",
|
||||||
|
"<extra_id_69>",
|
||||||
|
"<extra_id_70>",
|
||||||
|
"<extra_id_71>",
|
||||||
|
"<extra_id_72>",
|
||||||
|
"<extra_id_73>",
|
||||||
|
"<extra_id_74>",
|
||||||
|
"<extra_id_75>",
|
||||||
|
"<extra_id_76>",
|
||||||
|
"<extra_id_77>",
|
||||||
|
"<extra_id_78>",
|
||||||
|
"<extra_id_79>",
|
||||||
|
"<extra_id_80>",
|
||||||
|
"<extra_id_81>",
|
||||||
|
"<extra_id_82>",
|
||||||
|
"<extra_id_83>",
|
||||||
|
"<extra_id_84>",
|
||||||
|
"<extra_id_85>",
|
||||||
|
"<extra_id_86>",
|
||||||
|
"<extra_id_87>",
|
||||||
|
"<extra_id_88>",
|
||||||
|
"<extra_id_89>",
|
||||||
|
"<extra_id_90>",
|
||||||
|
"<extra_id_91>",
|
||||||
|
"<extra_id_92>",
|
||||||
|
"<extra_id_93>",
|
||||||
|
"<extra_id_94>",
|
||||||
|
"<extra_id_95>",
|
||||||
|
"<extra_id_96>",
|
||||||
|
"<extra_id_97>",
|
||||||
|
"<extra_id_98>",
|
||||||
|
"<extra_id_99>"
|
||||||
|
],
|
||||||
|
"eos_token": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"pad_token": {
|
||||||
|
"content": "<pad>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"unk_token": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
}
|
||||||
|
}
|
||||||
129428
comfy/t5_tokenizer/tokenizer.json
Normal file
129428
comfy/t5_tokenizer/tokenizer.json
Normal file
File diff suppressed because one or more lines are too long
939
comfy/t5_tokenizer/tokenizer_config.json
Normal file
939
comfy/t5_tokenizer/tokenizer_config.json
Normal file
@ -0,0 +1,939 @@
|
|||||||
|
{
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"0": {
|
||||||
|
"content": "<pad>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"1": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32000": {
|
||||||
|
"content": "<extra_id_99>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32001": {
|
||||||
|
"content": "<extra_id_98>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32002": {
|
||||||
|
"content": "<extra_id_97>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32003": {
|
||||||
|
"content": "<extra_id_96>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32004": {
|
||||||
|
"content": "<extra_id_95>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32005": {
|
||||||
|
"content": "<extra_id_94>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32006": {
|
||||||
|
"content": "<extra_id_93>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32007": {
|
||||||
|
"content": "<extra_id_92>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32008": {
|
||||||
|
"content": "<extra_id_91>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32009": {
|
||||||
|
"content": "<extra_id_90>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32010": {
|
||||||
|
"content": "<extra_id_89>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32011": {
|
||||||
|
"content": "<extra_id_88>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32012": {
|
||||||
|
"content": "<extra_id_87>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32013": {
|
||||||
|
"content": "<extra_id_86>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32014": {
|
||||||
|
"content": "<extra_id_85>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32015": {
|
||||||
|
"content": "<extra_id_84>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32016": {
|
||||||
|
"content": "<extra_id_83>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32017": {
|
||||||
|
"content": "<extra_id_82>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32018": {
|
||||||
|
"content": "<extra_id_81>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32019": {
|
||||||
|
"content": "<extra_id_80>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32020": {
|
||||||
|
"content": "<extra_id_79>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32021": {
|
||||||
|
"content": "<extra_id_78>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32022": {
|
||||||
|
"content": "<extra_id_77>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32023": {
|
||||||
|
"content": "<extra_id_76>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32024": {
|
||||||
|
"content": "<extra_id_75>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32025": {
|
||||||
|
"content": "<extra_id_74>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32026": {
|
||||||
|
"content": "<extra_id_73>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32027": {
|
||||||
|
"content": "<extra_id_72>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32028": {
|
||||||
|
"content": "<extra_id_71>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32029": {
|
||||||
|
"content": "<extra_id_70>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32030": {
|
||||||
|
"content": "<extra_id_69>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32031": {
|
||||||
|
"content": "<extra_id_68>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32032": {
|
||||||
|
"content": "<extra_id_67>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32033": {
|
||||||
|
"content": "<extra_id_66>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32034": {
|
||||||
|
"content": "<extra_id_65>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32035": {
|
||||||
|
"content": "<extra_id_64>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32036": {
|
||||||
|
"content": "<extra_id_63>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32037": {
|
||||||
|
"content": "<extra_id_62>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32038": {
|
||||||
|
"content": "<extra_id_61>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32039": {
|
||||||
|
"content": "<extra_id_60>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32040": {
|
||||||
|
"content": "<extra_id_59>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32041": {
|
||||||
|
"content": "<extra_id_58>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32042": {
|
||||||
|
"content": "<extra_id_57>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32043": {
|
||||||
|
"content": "<extra_id_56>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32044": {
|
||||||
|
"content": "<extra_id_55>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32045": {
|
||||||
|
"content": "<extra_id_54>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32046": {
|
||||||
|
"content": "<extra_id_53>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32047": {
|
||||||
|
"content": "<extra_id_52>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32048": {
|
||||||
|
"content": "<extra_id_51>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32049": {
|
||||||
|
"content": "<extra_id_50>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32050": {
|
||||||
|
"content": "<extra_id_49>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32051": {
|
||||||
|
"content": "<extra_id_48>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32052": {
|
||||||
|
"content": "<extra_id_47>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32053": {
|
||||||
|
"content": "<extra_id_46>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32054": {
|
||||||
|
"content": "<extra_id_45>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32055": {
|
||||||
|
"content": "<extra_id_44>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32056": {
|
||||||
|
"content": "<extra_id_43>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32057": {
|
||||||
|
"content": "<extra_id_42>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32058": {
|
||||||
|
"content": "<extra_id_41>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32059": {
|
||||||
|
"content": "<extra_id_40>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32060": {
|
||||||
|
"content": "<extra_id_39>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32061": {
|
||||||
|
"content": "<extra_id_38>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32062": {
|
||||||
|
"content": "<extra_id_37>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32063": {
|
||||||
|
"content": "<extra_id_36>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32064": {
|
||||||
|
"content": "<extra_id_35>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32065": {
|
||||||
|
"content": "<extra_id_34>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32066": {
|
||||||
|
"content": "<extra_id_33>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32067": {
|
||||||
|
"content": "<extra_id_32>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32068": {
|
||||||
|
"content": "<extra_id_31>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32069": {
|
||||||
|
"content": "<extra_id_30>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32070": {
|
||||||
|
"content": "<extra_id_29>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32071": {
|
||||||
|
"content": "<extra_id_28>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32072": {
|
||||||
|
"content": "<extra_id_27>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32073": {
|
||||||
|
"content": "<extra_id_26>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32074": {
|
||||||
|
"content": "<extra_id_25>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32075": {
|
||||||
|
"content": "<extra_id_24>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32076": {
|
||||||
|
"content": "<extra_id_23>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32077": {
|
||||||
|
"content": "<extra_id_22>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32078": {
|
||||||
|
"content": "<extra_id_21>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32079": {
|
||||||
|
"content": "<extra_id_20>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32080": {
|
||||||
|
"content": "<extra_id_19>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32081": {
|
||||||
|
"content": "<extra_id_18>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32082": {
|
||||||
|
"content": "<extra_id_17>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32083": {
|
||||||
|
"content": "<extra_id_16>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32084": {
|
||||||
|
"content": "<extra_id_15>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32085": {
|
||||||
|
"content": "<extra_id_14>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32086": {
|
||||||
|
"content": "<extra_id_13>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32087": {
|
||||||
|
"content": "<extra_id_12>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32088": {
|
||||||
|
"content": "<extra_id_11>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32089": {
|
||||||
|
"content": "<extra_id_10>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32090": {
|
||||||
|
"content": "<extra_id_9>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32091": {
|
||||||
|
"content": "<extra_id_8>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32092": {
|
||||||
|
"content": "<extra_id_7>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32093": {
|
||||||
|
"content": "<extra_id_6>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32094": {
|
||||||
|
"content": "<extra_id_5>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32095": {
|
||||||
|
"content": "<extra_id_4>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32096": {
|
||||||
|
"content": "<extra_id_3>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32097": {
|
||||||
|
"content": "<extra_id_2>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32098": {
|
||||||
|
"content": "<extra_id_1>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"32099": {
|
||||||
|
"content": "<extra_id_0>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<extra_id_0>",
|
||||||
|
"<extra_id_1>",
|
||||||
|
"<extra_id_2>",
|
||||||
|
"<extra_id_3>",
|
||||||
|
"<extra_id_4>",
|
||||||
|
"<extra_id_5>",
|
||||||
|
"<extra_id_6>",
|
||||||
|
"<extra_id_7>",
|
||||||
|
"<extra_id_8>",
|
||||||
|
"<extra_id_9>",
|
||||||
|
"<extra_id_10>",
|
||||||
|
"<extra_id_11>",
|
||||||
|
"<extra_id_12>",
|
||||||
|
"<extra_id_13>",
|
||||||
|
"<extra_id_14>",
|
||||||
|
"<extra_id_15>",
|
||||||
|
"<extra_id_16>",
|
||||||
|
"<extra_id_17>",
|
||||||
|
"<extra_id_18>",
|
||||||
|
"<extra_id_19>",
|
||||||
|
"<extra_id_20>",
|
||||||
|
"<extra_id_21>",
|
||||||
|
"<extra_id_22>",
|
||||||
|
"<extra_id_23>",
|
||||||
|
"<extra_id_24>",
|
||||||
|
"<extra_id_25>",
|
||||||
|
"<extra_id_26>",
|
||||||
|
"<extra_id_27>",
|
||||||
|
"<extra_id_28>",
|
||||||
|
"<extra_id_29>",
|
||||||
|
"<extra_id_30>",
|
||||||
|
"<extra_id_31>",
|
||||||
|
"<extra_id_32>",
|
||||||
|
"<extra_id_33>",
|
||||||
|
"<extra_id_34>",
|
||||||
|
"<extra_id_35>",
|
||||||
|
"<extra_id_36>",
|
||||||
|
"<extra_id_37>",
|
||||||
|
"<extra_id_38>",
|
||||||
|
"<extra_id_39>",
|
||||||
|
"<extra_id_40>",
|
||||||
|
"<extra_id_41>",
|
||||||
|
"<extra_id_42>",
|
||||||
|
"<extra_id_43>",
|
||||||
|
"<extra_id_44>",
|
||||||
|
"<extra_id_45>",
|
||||||
|
"<extra_id_46>",
|
||||||
|
"<extra_id_47>",
|
||||||
|
"<extra_id_48>",
|
||||||
|
"<extra_id_49>",
|
||||||
|
"<extra_id_50>",
|
||||||
|
"<extra_id_51>",
|
||||||
|
"<extra_id_52>",
|
||||||
|
"<extra_id_53>",
|
||||||
|
"<extra_id_54>",
|
||||||
|
"<extra_id_55>",
|
||||||
|
"<extra_id_56>",
|
||||||
|
"<extra_id_57>",
|
||||||
|
"<extra_id_58>",
|
||||||
|
"<extra_id_59>",
|
||||||
|
"<extra_id_60>",
|
||||||
|
"<extra_id_61>",
|
||||||
|
"<extra_id_62>",
|
||||||
|
"<extra_id_63>",
|
||||||
|
"<extra_id_64>",
|
||||||
|
"<extra_id_65>",
|
||||||
|
"<extra_id_66>",
|
||||||
|
"<extra_id_67>",
|
||||||
|
"<extra_id_68>",
|
||||||
|
"<extra_id_69>",
|
||||||
|
"<extra_id_70>",
|
||||||
|
"<extra_id_71>",
|
||||||
|
"<extra_id_72>",
|
||||||
|
"<extra_id_73>",
|
||||||
|
"<extra_id_74>",
|
||||||
|
"<extra_id_75>",
|
||||||
|
"<extra_id_76>",
|
||||||
|
"<extra_id_77>",
|
||||||
|
"<extra_id_78>",
|
||||||
|
"<extra_id_79>",
|
||||||
|
"<extra_id_80>",
|
||||||
|
"<extra_id_81>",
|
||||||
|
"<extra_id_82>",
|
||||||
|
"<extra_id_83>",
|
||||||
|
"<extra_id_84>",
|
||||||
|
"<extra_id_85>",
|
||||||
|
"<extra_id_86>",
|
||||||
|
"<extra_id_87>",
|
||||||
|
"<extra_id_88>",
|
||||||
|
"<extra_id_89>",
|
||||||
|
"<extra_id_90>",
|
||||||
|
"<extra_id_91>",
|
||||||
|
"<extra_id_92>",
|
||||||
|
"<extra_id_93>",
|
||||||
|
"<extra_id_94>",
|
||||||
|
"<extra_id_95>",
|
||||||
|
"<extra_id_96>",
|
||||||
|
"<extra_id_97>",
|
||||||
|
"<extra_id_98>",
|
||||||
|
"<extra_id_99>"
|
||||||
|
],
|
||||||
|
"clean_up_tokenization_spaces": true,
|
||||||
|
"eos_token": "</s>",
|
||||||
|
"extra_ids": 100,
|
||||||
|
"legacy": false,
|
||||||
|
"model_max_length": 512,
|
||||||
|
"pad_token": "<pad>",
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"tokenizer_class": "T5Tokenizer",
|
||||||
|
"unk_token": "<unk>"
|
||||||
|
}
|
||||||
@ -25,18 +25,19 @@ class Block(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.fuse(self.conv(x) + self.skip(x))
|
return self.fuse(self.conv(x) + self.skip(x))
|
||||||
|
|
||||||
def Encoder():
|
def Encoder(latent_channels=4):
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
conv(3, 64), Block(64, 64),
|
conv(3, 64), Block(64, 64),
|
||||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
conv(64, 4),
|
conv(64, latent_channels),
|
||||||
)
|
)
|
||||||
|
|
||||||
def Decoder():
|
|
||||||
|
def Decoder(latent_channels=4):
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
Clamp(), conv(4, 64), nn.ReLU(),
|
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
@ -47,12 +48,13 @@ class TAESD(nn.Module):
|
|||||||
latent_magnitude = 3
|
latent_magnitude = 3
|
||||||
latent_shift = 0.5
|
latent_shift = 0.5
|
||||||
|
|
||||||
def __init__(self, encoder_path=None, decoder_path=None):
|
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
|
||||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.taesd_encoder = Encoder()
|
self.taesd_encoder = Encoder(latent_channels=latent_channels)
|
||||||
self.taesd_decoder = Decoder()
|
self.taesd_decoder = Decoder(latent_channels=latent_channels)
|
||||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||||
|
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
|
||||||
if encoder_path is not None:
|
if encoder_path is not None:
|
||||||
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||||
if decoder_path is not None:
|
if decoder_path is not None:
|
||||||
@ -69,9 +71,9 @@ class TAESD(nn.Module):
|
|||||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||||
|
|
||||||
def decode(self, x):
|
def decode(self, x):
|
||||||
x_sample = self.taesd_decoder(x * self.vae_scale)
|
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
|
||||||
x_sample = x_sample.sub(0.5).mul(2)
|
x_sample = x_sample.sub(0.5).mul(2)
|
||||||
return x_sample
|
return x_sample
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
|
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
||||||
|
|||||||
32
comfy/types.py
Normal file
32
comfy/types.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Callable, Protocol, TypedDict, Optional, List
|
||||||
|
|
||||||
|
|
||||||
|
class UnetApplyFunction(Protocol):
|
||||||
|
"""Function signature protocol on comfy.model_base.BaseModel.apply_model"""
|
||||||
|
|
||||||
|
def __call__(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnetApplyConds(TypedDict):
|
||||||
|
"""Optional conditions for unet apply function."""
|
||||||
|
|
||||||
|
c_concat: Optional[torch.Tensor]
|
||||||
|
c_crossattn: Optional[torch.Tensor]
|
||||||
|
control: Optional[torch.Tensor]
|
||||||
|
transformer_options: Optional[dict]
|
||||||
|
|
||||||
|
|
||||||
|
class UnetParams(TypedDict):
|
||||||
|
# Tensor of shape [B, C, H, W]
|
||||||
|
input: torch.Tensor
|
||||||
|
# Tensor of shape [B]
|
||||||
|
timestep: torch.Tensor
|
||||||
|
c: UnetApplyConds
|
||||||
|
# List of [0, 1], [0], [1], ...
|
||||||
|
# 0 means conditional, 1 means conditional unconditional
|
||||||
|
cond_or_uncond: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
|
||||||
@ -249,11 +249,11 @@ def unet_to_diffusers(unet_config):
|
|||||||
|
|
||||||
return diffusers_unet_map
|
return diffusers_unet_map
|
||||||
|
|
||||||
def repeat_to_batch_size(tensor, batch_size):
|
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||||
if tensor.shape[0] > batch_size:
|
if tensor.shape[dim] > batch_size:
|
||||||
return tensor[:batch_size]
|
return tensor.narrow(dim, 0, batch_size)
|
||||||
elif tensor.shape[0] < batch_size:
|
elif tensor.shape[dim] < batch_size:
|
||||||
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
|
return tensor.repeat(dim * [1] + [math.ceil(batch_size / tensor.shape[dim])] + [1] * (len(tensor.shape) - 1 - dim)).narrow(dim, 0, batch_size)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def resize_to_batch_size(tensor, batch_size):
|
def resize_to_batch_size(tensor, batch_size):
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@ -1,21 +0,0 @@
|
|||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2022 Xiangyu Chen
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
@ -1,29 +0,0 @@
|
|||||||
BSD 3-Clause License
|
|
||||||
|
|
||||||
Copyright (c) 2021, Xintao Wang
|
|
||||||
All rights reserved.
|
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
|
||||||
modification, are permitted provided that the following conditions are met:
|
|
||||||
|
|
||||||
1. Redistributions of source code must retain the above copyright notice, this
|
|
||||||
list of conditions and the following disclaimer.
|
|
||||||
|
|
||||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
||||||
this list of conditions and the following disclaimer in the documentation
|
|
||||||
and/or other materials provided with the distribution.
|
|
||||||
|
|
||||||
3. Neither the name of the copyright holder nor the names of its
|
|
||||||
contributors may be used to endorse or promote products derived from
|
|
||||||
this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
||||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
||||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
||||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
||||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
||||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
||||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
||||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright 2022 Kai Zhang (cskaizhang@gmail.com, https://cszn.github.io/). All rights reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright 2018-2022 BasicSR Authors
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@ -1,121 +0,0 @@
|
|||||||
Creative Commons Legal Code
|
|
||||||
|
|
||||||
CC0 1.0 Universal
|
|
||||||
|
|
||||||
CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
|
|
||||||
LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
|
|
||||||
ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
|
|
||||||
INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
|
|
||||||
REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
|
|
||||||
PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
|
|
||||||
THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
|
|
||||||
HEREUNDER.
|
|
||||||
|
|
||||||
Statement of Purpose
|
|
||||||
|
|
||||||
The laws of most jurisdictions throughout the world automatically confer
|
|
||||||
exclusive Copyright and Related Rights (defined below) upon the creator
|
|
||||||
and subsequent owner(s) (each and all, an "owner") of an original work of
|
|
||||||
authorship and/or a database (each, a "Work").
|
|
||||||
|
|
||||||
Certain owners wish to permanently relinquish those rights to a Work for
|
|
||||||
the purpose of contributing to a commons of creative, cultural and
|
|
||||||
scientific works ("Commons") that the public can reliably and without fear
|
|
||||||
of later claims of infringement build upon, modify, incorporate in other
|
|
||||||
works, reuse and redistribute as freely as possible in any form whatsoever
|
|
||||||
and for any purposes, including without limitation commercial purposes.
|
|
||||||
These owners may contribute to the Commons to promote the ideal of a free
|
|
||||||
culture and the further production of creative, cultural and scientific
|
|
||||||
works, or to gain reputation or greater distribution for their Work in
|
|
||||||
part through the use and efforts of others.
|
|
||||||
|
|
||||||
For these and/or other purposes and motivations, and without any
|
|
||||||
expectation of additional consideration or compensation, the person
|
|
||||||
associating CC0 with a Work (the "Affirmer"), to the extent that he or she
|
|
||||||
is an owner of Copyright and Related Rights in the Work, voluntarily
|
|
||||||
elects to apply CC0 to the Work and publicly distribute the Work under its
|
|
||||||
terms, with knowledge of his or her Copyright and Related Rights in the
|
|
||||||
Work and the meaning and intended legal effect of CC0 on those rights.
|
|
||||||
|
|
||||||
1. Copyright and Related Rights. A Work made available under CC0 may be
|
|
||||||
protected by copyright and related or neighboring rights ("Copyright and
|
|
||||||
Related Rights"). Copyright and Related Rights include, but are not
|
|
||||||
limited to, the following:
|
|
||||||
|
|
||||||
i. the right to reproduce, adapt, distribute, perform, display,
|
|
||||||
communicate, and translate a Work;
|
|
||||||
ii. moral rights retained by the original author(s) and/or performer(s);
|
|
||||||
iii. publicity and privacy rights pertaining to a person's image or
|
|
||||||
likeness depicted in a Work;
|
|
||||||
iv. rights protecting against unfair competition in regards to a Work,
|
|
||||||
subject to the limitations in paragraph 4(a), below;
|
|
||||||
v. rights protecting the extraction, dissemination, use and reuse of data
|
|
||||||
in a Work;
|
|
||||||
vi. database rights (such as those arising under Directive 96/9/EC of the
|
|
||||||
European Parliament and of the Council of 11 March 1996 on the legal
|
|
||||||
protection of databases, and under any national implementation
|
|
||||||
thereof, including any amended or successor version of such
|
|
||||||
directive); and
|
|
||||||
vii. other similar, equivalent or corresponding rights throughout the
|
|
||||||
world based on applicable law or treaty, and any national
|
|
||||||
implementations thereof.
|
|
||||||
|
|
||||||
2. Waiver. To the greatest extent permitted by, but not in contravention
|
|
||||||
of, applicable law, Affirmer hereby overtly, fully, permanently,
|
|
||||||
irrevocably and unconditionally waives, abandons, and surrenders all of
|
|
||||||
Affirmer's Copyright and Related Rights and associated claims and causes
|
|
||||||
of action, whether now known or unknown (including existing as well as
|
|
||||||
future claims and causes of action), in the Work (i) in all territories
|
|
||||||
worldwide, (ii) for the maximum duration provided by applicable law or
|
|
||||||
treaty (including future time extensions), (iii) in any current or future
|
|
||||||
medium and for any number of copies, and (iv) for any purpose whatsoever,
|
|
||||||
including without limitation commercial, advertising or promotional
|
|
||||||
purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
|
|
||||||
member of the public at large and to the detriment of Affirmer's heirs and
|
|
||||||
successors, fully intending that such Waiver shall not be subject to
|
|
||||||
revocation, rescission, cancellation, termination, or any other legal or
|
|
||||||
equitable action to disrupt the quiet enjoyment of the Work by the public
|
|
||||||
as contemplated by Affirmer's express Statement of Purpose.
|
|
||||||
|
|
||||||
3. Public License Fallback. Should any part of the Waiver for any reason
|
|
||||||
be judged legally invalid or ineffective under applicable law, then the
|
|
||||||
Waiver shall be preserved to the maximum extent permitted taking into
|
|
||||||
account Affirmer's express Statement of Purpose. In addition, to the
|
|
||||||
extent the Waiver is so judged Affirmer hereby grants to each affected
|
|
||||||
person a royalty-free, non transferable, non sublicensable, non exclusive,
|
|
||||||
irrevocable and unconditional license to exercise Affirmer's Copyright and
|
|
||||||
Related Rights in the Work (i) in all territories worldwide, (ii) for the
|
|
||||||
maximum duration provided by applicable law or treaty (including future
|
|
||||||
time extensions), (iii) in any current or future medium and for any number
|
|
||||||
of copies, and (iv) for any purpose whatsoever, including without
|
|
||||||
limitation commercial, advertising or promotional purposes (the
|
|
||||||
"License"). The License shall be deemed effective as of the date CC0 was
|
|
||||||
applied by Affirmer to the Work. Should any part of the License for any
|
|
||||||
reason be judged legally invalid or ineffective under applicable law, such
|
|
||||||
partial invalidity or ineffectiveness shall not invalidate the remainder
|
|
||||||
of the License, and in such case Affirmer hereby affirms that he or she
|
|
||||||
will not (i) exercise any of his or her remaining Copyright and Related
|
|
||||||
Rights in the Work or (ii) assert any associated claims and causes of
|
|
||||||
action with respect to the Work, in either case contrary to Affirmer's
|
|
||||||
express Statement of Purpose.
|
|
||||||
|
|
||||||
4. Limitations and Disclaimers.
|
|
||||||
|
|
||||||
a. No trademark or patent rights held by Affirmer are waived, abandoned,
|
|
||||||
surrendered, licensed or otherwise affected by this document.
|
|
||||||
b. Affirmer offers the Work as-is and makes no representations or
|
|
||||||
warranties of any kind concerning the Work, express, implied,
|
|
||||||
statutory or otherwise, including without limitation warranties of
|
|
||||||
title, merchantability, fitness for a particular purpose, non
|
|
||||||
infringement, or the absence of latent or other defects, accuracy, or
|
|
||||||
the present or absence of errors, whether or not discoverable, all to
|
|
||||||
the greatest extent permissible under applicable law.
|
|
||||||
c. Affirmer disclaims responsibility for clearing rights of other persons
|
|
||||||
that may apply to the Work or any use thereof, including without
|
|
||||||
limitation any person's Copyright and Related Rights in the Work.
|
|
||||||
Further, Affirmer disclaims responsibility for obtaining any necessary
|
|
||||||
consents, permissions or other rights required for any use of the
|
|
||||||
Work.
|
|
||||||
d. Affirmer understands and acknowledges that Creative Commons is not a
|
|
||||||
party to this document and has no duty or obligation with respect to
|
|
||||||
this CC0 or use of the Work.
|
|
||||||
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [2021] [SwinIR Authors]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [2021] [SwinIR Authors]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [2021] Samsung Research
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@ -1,694 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
"""
|
|
||||||
Model adapted from advimman's lama project: https://github.com/advimman/lama
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Fast Fourier Convolution NeurIPS 2020
|
|
||||||
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
|
|
||||||
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torchvision.transforms.functional import InterpolationMode, rotate
|
|
||||||
|
|
||||||
|
|
||||||
class LearnableSpatialTransformWrapper(nn.Module):
|
|
||||||
def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
|
|
||||||
super().__init__()
|
|
||||||
self.impl = impl
|
|
||||||
self.angle = torch.rand(1) * angle_init_range
|
|
||||||
if train_angle:
|
|
||||||
self.angle = nn.Parameter(self.angle, requires_grad=True)
|
|
||||||
self.pad_coef = pad_coef
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if torch.is_tensor(x):
|
|
||||||
return self.inverse_transform(self.impl(self.transform(x)), x)
|
|
||||||
elif isinstance(x, tuple):
|
|
||||||
x_trans = tuple(self.transform(elem) for elem in x)
|
|
||||||
y_trans = self.impl(x_trans)
|
|
||||||
return tuple(
|
|
||||||
self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unexpected input type {type(x)}")
|
|
||||||
|
|
||||||
def transform(self, x):
|
|
||||||
height, width = x.shape[2:]
|
|
||||||
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
|
|
||||||
x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode="reflect")
|
|
||||||
x_padded_rotated = rotate(
|
|
||||||
x_padded, self.angle.to(x_padded), InterpolationMode.BILINEAR, fill=0
|
|
||||||
)
|
|
||||||
|
|
||||||
return x_padded_rotated
|
|
||||||
|
|
||||||
def inverse_transform(self, y_padded_rotated, orig_x):
|
|
||||||
height, width = orig_x.shape[2:]
|
|
||||||
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
|
|
||||||
|
|
||||||
y_padded = rotate(
|
|
||||||
y_padded_rotated,
|
|
||||||
-self.angle.to(y_padded_rotated),
|
|
||||||
InterpolationMode.BILINEAR,
|
|
||||||
fill=0,
|
|
||||||
)
|
|
||||||
y_height, y_width = y_padded.shape[2:]
|
|
||||||
y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
class SELayer(nn.Module):
|
|
||||||
def __init__(self, channel, reduction=16):
|
|
||||||
super(SELayer, self).__init__()
|
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Linear(channel, channel // reduction, bias=False),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Linear(channel // reduction, channel, bias=False),
|
|
||||||
nn.Sigmoid(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b, c, _, _ = x.size()
|
|
||||||
y = self.avg_pool(x).view(b, c)
|
|
||||||
y = self.fc(y).view(b, c, 1, 1)
|
|
||||||
res = x * y.expand_as(x)
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class FourierUnit(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
groups=1,
|
|
||||||
spatial_scale_factor=None,
|
|
||||||
spatial_scale_mode="bilinear",
|
|
||||||
spectral_pos_encoding=False,
|
|
||||||
use_se=False,
|
|
||||||
se_kwargs=None,
|
|
||||||
ffc3d=False,
|
|
||||||
fft_norm="ortho",
|
|
||||||
):
|
|
||||||
# bn_layer not used
|
|
||||||
super(FourierUnit, self).__init__()
|
|
||||||
self.groups = groups
|
|
||||||
|
|
||||||
self.conv_layer = torch.nn.Conv2d(
|
|
||||||
in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
|
|
||||||
out_channels=out_channels * 2,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
groups=self.groups,
|
|
||||||
bias=False,
|
|
||||||
)
|
|
||||||
self.bn = torch.nn.BatchNorm2d(out_channels * 2)
|
|
||||||
self.relu = torch.nn.ReLU(inplace=True)
|
|
||||||
|
|
||||||
# squeeze and excitation block
|
|
||||||
self.use_se = use_se
|
|
||||||
if use_se:
|
|
||||||
if se_kwargs is None:
|
|
||||||
se_kwargs = {}
|
|
||||||
self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
|
|
||||||
|
|
||||||
self.spatial_scale_factor = spatial_scale_factor
|
|
||||||
self.spatial_scale_mode = spatial_scale_mode
|
|
||||||
self.spectral_pos_encoding = spectral_pos_encoding
|
|
||||||
self.ffc3d = ffc3d
|
|
||||||
self.fft_norm = fft_norm
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
half_check = False
|
|
||||||
if x.type() == "torch.cuda.HalfTensor":
|
|
||||||
# half only works on gpu anyway
|
|
||||||
half_check = True
|
|
||||||
|
|
||||||
batch = x.shape[0]
|
|
||||||
|
|
||||||
if self.spatial_scale_factor is not None:
|
|
||||||
orig_size = x.shape[-2:]
|
|
||||||
x = F.interpolate(
|
|
||||||
x,
|
|
||||||
scale_factor=self.spatial_scale_factor,
|
|
||||||
mode=self.spatial_scale_mode,
|
|
||||||
align_corners=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# (batch, c, h, w/2+1, 2)
|
|
||||||
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
|
|
||||||
if half_check == True:
|
|
||||||
ffted = torch.fft.rfftn(
|
|
||||||
x.float(), dim=fft_dim, norm=self.fft_norm
|
|
||||||
) # .type(torch.cuda.HalfTensor)
|
|
||||||
else:
|
|
||||||
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
|
|
||||||
|
|
||||||
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
|
||||||
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
|
||||||
ffted = ffted.view(
|
|
||||||
(
|
|
||||||
batch,
|
|
||||||
-1,
|
|
||||||
)
|
|
||||||
+ ffted.size()[3:]
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.spectral_pos_encoding:
|
|
||||||
height, width = ffted.shape[-2:]
|
|
||||||
coords_vert = (
|
|
||||||
torch.linspace(0, 1, height)[None, None, :, None]
|
|
||||||
.expand(batch, 1, height, width)
|
|
||||||
.to(ffted)
|
|
||||||
)
|
|
||||||
coords_hor = (
|
|
||||||
torch.linspace(0, 1, width)[None, None, None, :]
|
|
||||||
.expand(batch, 1, height, width)
|
|
||||||
.to(ffted)
|
|
||||||
)
|
|
||||||
ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
|
|
||||||
|
|
||||||
if self.use_se:
|
|
||||||
ffted = self.se(ffted)
|
|
||||||
|
|
||||||
if half_check == True:
|
|
||||||
ffted = self.conv_layer(ffted.half()) # (batch, c*2, h, w/2+1)
|
|
||||||
else:
|
|
||||||
ffted = self.conv_layer(
|
|
||||||
ffted
|
|
||||||
) # .type(torch.cuda.FloatTensor) # (batch, c*2, h, w/2+1)
|
|
||||||
|
|
||||||
ffted = self.relu(self.bn(ffted))
|
|
||||||
# forcing to be always float
|
|
||||||
ffted = ffted.float()
|
|
||||||
|
|
||||||
ffted = (
|
|
||||||
ffted.view(
|
|
||||||
(
|
|
||||||
batch,
|
|
||||||
-1,
|
|
||||||
2,
|
|
||||||
)
|
|
||||||
+ ffted.size()[2:]
|
|
||||||
)
|
|
||||||
.permute(0, 1, 3, 4, 2)
|
|
||||||
.contiguous()
|
|
||||||
) # (batch,c, t, h, w/2+1, 2)
|
|
||||||
|
|
||||||
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
|
||||||
|
|
||||||
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
|
|
||||||
output = torch.fft.irfftn(
|
|
||||||
ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
|
|
||||||
)
|
|
||||||
|
|
||||||
if half_check == True:
|
|
||||||
output = output.half()
|
|
||||||
|
|
||||||
if self.spatial_scale_factor is not None:
|
|
||||||
output = F.interpolate(
|
|
||||||
output,
|
|
||||||
size=orig_size,
|
|
||||||
mode=self.spatial_scale_mode,
|
|
||||||
align_corners=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class SpectralTransform(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
stride=1,
|
|
||||||
groups=1,
|
|
||||||
enable_lfu=True,
|
|
||||||
separable_fu=False,
|
|
||||||
**fu_kwargs,
|
|
||||||
):
|
|
||||||
# bn_layer not used
|
|
||||||
super(SpectralTransform, self).__init__()
|
|
||||||
self.enable_lfu = enable_lfu
|
|
||||||
if stride == 2:
|
|
||||||
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
|
|
||||||
else:
|
|
||||||
self.downsample = nn.Identity()
|
|
||||||
|
|
||||||
self.stride = stride
|
|
||||||
self.conv1 = nn.Sequential(
|
|
||||||
nn.Conv2d(
|
|
||||||
in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
|
|
||||||
),
|
|
||||||
nn.BatchNorm2d(out_channels // 2),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
)
|
|
||||||
fu_class = FourierUnit
|
|
||||||
self.fu = fu_class(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
|
|
||||||
if self.enable_lfu:
|
|
||||||
self.lfu = fu_class(out_channels // 2, out_channels // 2, groups)
|
|
||||||
self.conv2 = torch.nn.Conv2d(
|
|
||||||
out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.downsample(x)
|
|
||||||
x = self.conv1(x)
|
|
||||||
output = self.fu(x)
|
|
||||||
|
|
||||||
if self.enable_lfu:
|
|
||||||
_, c, h, _ = x.shape
|
|
||||||
split_no = 2
|
|
||||||
split_s = h // split_no
|
|
||||||
xs = torch.cat(
|
|
||||||
torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
|
|
||||||
).contiguous()
|
|
||||||
xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
|
|
||||||
xs = self.lfu(xs)
|
|
||||||
xs = xs.repeat(1, 1, split_no, split_no).contiguous()
|
|
||||||
else:
|
|
||||||
xs = 0
|
|
||||||
|
|
||||||
output = self.conv2(x + output + xs)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class FFC(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
ratio_gin,
|
|
||||||
ratio_gout,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
bias=False,
|
|
||||||
enable_lfu=True,
|
|
||||||
padding_type="reflect",
|
|
||||||
gated=False,
|
|
||||||
**spectral_kwargs,
|
|
||||||
):
|
|
||||||
super(FFC, self).__init__()
|
|
||||||
|
|
||||||
assert stride == 1 or stride == 2, "Stride should be 1 or 2."
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
in_cg = int(in_channels * ratio_gin)
|
|
||||||
in_cl = in_channels - in_cg
|
|
||||||
out_cg = int(out_channels * ratio_gout)
|
|
||||||
out_cl = out_channels - out_cg
|
|
||||||
# groups_g = 1 if groups == 1 else int(groups * ratio_gout)
|
|
||||||
# groups_l = 1 if groups == 1 else groups - groups_g
|
|
||||||
|
|
||||||
self.ratio_gin = ratio_gin
|
|
||||||
self.ratio_gout = ratio_gout
|
|
||||||
self.global_in_num = in_cg
|
|
||||||
|
|
||||||
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
|
|
||||||
self.convl2l = module(
|
|
||||||
in_cl,
|
|
||||||
out_cl,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
padding,
|
|
||||||
dilation,
|
|
||||||
groups,
|
|
||||||
bias,
|
|
||||||
padding_mode=padding_type,
|
|
||||||
)
|
|
||||||
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
|
|
||||||
self.convl2g = module(
|
|
||||||
in_cl,
|
|
||||||
out_cg,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
padding,
|
|
||||||
dilation,
|
|
||||||
groups,
|
|
||||||
bias,
|
|
||||||
padding_mode=padding_type,
|
|
||||||
)
|
|
||||||
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
|
|
||||||
self.convg2l = module(
|
|
||||||
in_cg,
|
|
||||||
out_cl,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
padding,
|
|
||||||
dilation,
|
|
||||||
groups,
|
|
||||||
bias,
|
|
||||||
padding_mode=padding_type,
|
|
||||||
)
|
|
||||||
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
|
|
||||||
self.convg2g = module(
|
|
||||||
in_cg,
|
|
||||||
out_cg,
|
|
||||||
stride,
|
|
||||||
1 if groups == 1 else groups // 2,
|
|
||||||
enable_lfu,
|
|
||||||
**spectral_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.gated = gated
|
|
||||||
module = (
|
|
||||||
nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
|
|
||||||
)
|
|
||||||
self.gate = module(in_channels, 2, 1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x_l, x_g = x if type(x) is tuple else (x, 0)
|
|
||||||
out_xl, out_xg = 0, 0
|
|
||||||
|
|
||||||
if self.gated:
|
|
||||||
total_input_parts = [x_l]
|
|
||||||
if torch.is_tensor(x_g):
|
|
||||||
total_input_parts.append(x_g)
|
|
||||||
total_input = torch.cat(total_input_parts, dim=1)
|
|
||||||
|
|
||||||
gates = torch.sigmoid(self.gate(total_input))
|
|
||||||
g2l_gate, l2g_gate = gates.chunk(2, dim=1)
|
|
||||||
else:
|
|
||||||
g2l_gate, l2g_gate = 1, 1
|
|
||||||
|
|
||||||
if self.ratio_gout != 1:
|
|
||||||
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
|
|
||||||
if self.ratio_gout != 0:
|
|
||||||
out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
|
|
||||||
|
|
||||||
return out_xl, out_xg
|
|
||||||
|
|
||||||
|
|
||||||
class FFC_BN_ACT(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
ratio_gin,
|
|
||||||
ratio_gout,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
bias=False,
|
|
||||||
norm_layer=nn.BatchNorm2d,
|
|
||||||
activation_layer=nn.Identity,
|
|
||||||
padding_type="reflect",
|
|
||||||
enable_lfu=True,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super(FFC_BN_ACT, self).__init__()
|
|
||||||
self.ffc = FFC(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
ratio_gin,
|
|
||||||
ratio_gout,
|
|
||||||
stride,
|
|
||||||
padding,
|
|
||||||
dilation,
|
|
||||||
groups,
|
|
||||||
bias,
|
|
||||||
enable_lfu,
|
|
||||||
padding_type=padding_type,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
lnorm = nn.Identity if ratio_gout == 1 else norm_layer
|
|
||||||
gnorm = nn.Identity if ratio_gout == 0 else norm_layer
|
|
||||||
global_channels = int(out_channels * ratio_gout)
|
|
||||||
self.bn_l = lnorm(out_channels - global_channels)
|
|
||||||
self.bn_g = gnorm(global_channels)
|
|
||||||
|
|
||||||
lact = nn.Identity if ratio_gout == 1 else activation_layer
|
|
||||||
gact = nn.Identity if ratio_gout == 0 else activation_layer
|
|
||||||
self.act_l = lact(inplace=True)
|
|
||||||
self.act_g = gact(inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x_l, x_g = self.ffc(x)
|
|
||||||
x_l = self.act_l(self.bn_l(x_l))
|
|
||||||
x_g = self.act_g(self.bn_g(x_g))
|
|
||||||
return x_l, x_g
|
|
||||||
|
|
||||||
|
|
||||||
class FFCResnetBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
padding_type,
|
|
||||||
norm_layer,
|
|
||||||
activation_layer=nn.ReLU,
|
|
||||||
dilation=1,
|
|
||||||
spatial_transform_kwargs=None,
|
|
||||||
inline=False,
|
|
||||||
**conv_kwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.conv1 = FFC_BN_ACT(
|
|
||||||
dim,
|
|
||||||
dim,
|
|
||||||
kernel_size=3,
|
|
||||||
padding=dilation,
|
|
||||||
dilation=dilation,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
activation_layer=activation_layer,
|
|
||||||
padding_type=padding_type,
|
|
||||||
**conv_kwargs,
|
|
||||||
)
|
|
||||||
self.conv2 = FFC_BN_ACT(
|
|
||||||
dim,
|
|
||||||
dim,
|
|
||||||
kernel_size=3,
|
|
||||||
padding=dilation,
|
|
||||||
dilation=dilation,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
activation_layer=activation_layer,
|
|
||||||
padding_type=padding_type,
|
|
||||||
**conv_kwargs,
|
|
||||||
)
|
|
||||||
if spatial_transform_kwargs is not None:
|
|
||||||
self.conv1 = LearnableSpatialTransformWrapper(
|
|
||||||
self.conv1, **spatial_transform_kwargs
|
|
||||||
)
|
|
||||||
self.conv2 = LearnableSpatialTransformWrapper(
|
|
||||||
self.conv2, **spatial_transform_kwargs
|
|
||||||
)
|
|
||||||
self.inline = inline
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.inline:
|
|
||||||
x_l, x_g = (
|
|
||||||
x[:, : -self.conv1.ffc.global_in_num],
|
|
||||||
x[:, -self.conv1.ffc.global_in_num :],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
x_l, x_g = x if type(x) is tuple else (x, 0)
|
|
||||||
|
|
||||||
id_l, id_g = x_l, x_g
|
|
||||||
|
|
||||||
x_l, x_g = self.conv1((x_l, x_g))
|
|
||||||
x_l, x_g = self.conv2((x_l, x_g))
|
|
||||||
|
|
||||||
x_l, x_g = id_l + x_l, id_g + x_g
|
|
||||||
out = x_l, x_g
|
|
||||||
if self.inline:
|
|
||||||
out = torch.cat(out, dim=1)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ConcatTupleLayer(nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
assert isinstance(x, tuple)
|
|
||||||
x_l, x_g = x
|
|
||||||
assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
|
|
||||||
if not torch.is_tensor(x_g):
|
|
||||||
return x_l
|
|
||||||
return torch.cat(x, dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
class FFCResNetGenerator(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_nc,
|
|
||||||
output_nc,
|
|
||||||
ngf=64,
|
|
||||||
n_downsampling=3,
|
|
||||||
n_blocks=18,
|
|
||||||
norm_layer=nn.BatchNorm2d,
|
|
||||||
padding_type="reflect",
|
|
||||||
activation_layer=nn.ReLU,
|
|
||||||
up_norm_layer=nn.BatchNorm2d,
|
|
||||||
up_activation=nn.ReLU(True),
|
|
||||||
init_conv_kwargs={},
|
|
||||||
downsample_conv_kwargs={},
|
|
||||||
resnet_conv_kwargs={},
|
|
||||||
spatial_transform_layers=None,
|
|
||||||
spatial_transform_kwargs={},
|
|
||||||
max_features=1024,
|
|
||||||
out_ffc=False,
|
|
||||||
out_ffc_kwargs={},
|
|
||||||
):
|
|
||||||
assert n_blocks >= 0
|
|
||||||
super().__init__()
|
|
||||||
"""
|
|
||||||
init_conv_kwargs = {'ratio_gin': 0, 'ratio_gout': 0, 'enable_lfu': False}
|
|
||||||
downsample_conv_kwargs = {'ratio_gin': '${generator.init_conv_kwargs.ratio_gout}', 'ratio_gout': '${generator.downsample_conv_kwargs.ratio_gin}', 'enable_lfu': False}
|
|
||||||
resnet_conv_kwargs = {'ratio_gin': 0.75, 'ratio_gout': '${generator.resnet_conv_kwargs.ratio_gin}', 'enable_lfu': False}
|
|
||||||
spatial_transform_kwargs = {}
|
|
||||||
out_ffc_kwargs = {}
|
|
||||||
"""
|
|
||||||
"""
|
|
||||||
print(input_nc, output_nc, ngf, n_downsampling, n_blocks, norm_layer,
|
|
||||||
padding_type, activation_layer,
|
|
||||||
up_norm_layer, up_activation,
|
|
||||||
spatial_transform_layers,
|
|
||||||
add_out_act, max_features, out_ffc, file=sys.stderr)
|
|
||||||
|
|
||||||
4 3 64 3 18 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
|
|
||||||
reflect <class 'torch.nn.modules.activation.ReLU'>
|
|
||||||
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
|
|
||||||
ReLU(inplace=True)
|
|
||||||
None sigmoid 1024 False
|
|
||||||
"""
|
|
||||||
init_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
|
|
||||||
downsample_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
|
|
||||||
resnet_conv_kwargs = {
|
|
||||||
"ratio_gin": 0.75,
|
|
||||||
"ratio_gout": 0.75,
|
|
||||||
"enable_lfu": False,
|
|
||||||
}
|
|
||||||
spatial_transform_kwargs = {}
|
|
||||||
out_ffc_kwargs = {}
|
|
||||||
|
|
||||||
model = [
|
|
||||||
nn.ReflectionPad2d(3),
|
|
||||||
FFC_BN_ACT(
|
|
||||||
input_nc,
|
|
||||||
ngf,
|
|
||||||
kernel_size=7,
|
|
||||||
padding=0,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
activation_layer=activation_layer,
|
|
||||||
**init_conv_kwargs,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
### downsample
|
|
||||||
for i in range(n_downsampling):
|
|
||||||
mult = 2**i
|
|
||||||
if i == n_downsampling - 1:
|
|
||||||
cur_conv_kwargs = dict(downsample_conv_kwargs)
|
|
||||||
cur_conv_kwargs["ratio_gout"] = resnet_conv_kwargs.get("ratio_gin", 0)
|
|
||||||
else:
|
|
||||||
cur_conv_kwargs = downsample_conv_kwargs
|
|
||||||
model += [
|
|
||||||
FFC_BN_ACT(
|
|
||||||
min(max_features, ngf * mult),
|
|
||||||
min(max_features, ngf * mult * 2),
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2,
|
|
||||||
padding=1,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
activation_layer=activation_layer,
|
|
||||||
**cur_conv_kwargs,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
mult = 2**n_downsampling
|
|
||||||
feats_num_bottleneck = min(max_features, ngf * mult)
|
|
||||||
|
|
||||||
### resnet blocks
|
|
||||||
for i in range(n_blocks):
|
|
||||||
cur_resblock = FFCResnetBlock(
|
|
||||||
feats_num_bottleneck,
|
|
||||||
padding_type=padding_type,
|
|
||||||
activation_layer=activation_layer,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
**resnet_conv_kwargs,
|
|
||||||
)
|
|
||||||
if spatial_transform_layers is not None and i in spatial_transform_layers:
|
|
||||||
cur_resblock = LearnableSpatialTransformWrapper(
|
|
||||||
cur_resblock, **spatial_transform_kwargs
|
|
||||||
)
|
|
||||||
model += [cur_resblock]
|
|
||||||
|
|
||||||
model += [ConcatTupleLayer()]
|
|
||||||
|
|
||||||
### upsample
|
|
||||||
for i in range(n_downsampling):
|
|
||||||
mult = 2 ** (n_downsampling - i)
|
|
||||||
model += [
|
|
||||||
nn.ConvTranspose2d(
|
|
||||||
min(max_features, ngf * mult),
|
|
||||||
min(max_features, int(ngf * mult / 2)),
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2,
|
|
||||||
padding=1,
|
|
||||||
output_padding=1,
|
|
||||||
),
|
|
||||||
up_norm_layer(min(max_features, int(ngf * mult / 2))),
|
|
||||||
up_activation,
|
|
||||||
]
|
|
||||||
|
|
||||||
if out_ffc:
|
|
||||||
model += [
|
|
||||||
FFCResnetBlock(
|
|
||||||
ngf,
|
|
||||||
padding_type=padding_type,
|
|
||||||
activation_layer=activation_layer,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
inline=True,
|
|
||||||
**out_ffc_kwargs,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
model += [
|
|
||||||
nn.ReflectionPad2d(3),
|
|
||||||
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
|
|
||||||
]
|
|
||||||
model.append(nn.Sigmoid())
|
|
||||||
self.model = nn.Sequential(*model)
|
|
||||||
|
|
||||||
def forward(self, image, mask):
|
|
||||||
return self.model(torch.cat([image, mask], dim=1))
|
|
||||||
|
|
||||||
|
|
||||||
class LaMa(nn.Module):
|
|
||||||
def __init__(self, state_dict) -> None:
|
|
||||||
super(LaMa, self).__init__()
|
|
||||||
self.model_arch = "LaMa"
|
|
||||||
self.sub_type = "Inpaint"
|
|
||||||
self.in_nc = 4
|
|
||||||
self.out_nc = 3
|
|
||||||
self.scale = 1
|
|
||||||
|
|
||||||
self.min_size = None
|
|
||||||
self.pad_mod = 8
|
|
||||||
self.pad_to_square = False
|
|
||||||
|
|
||||||
self.model = FFCResNetGenerator(self.in_nc, self.out_nc)
|
|
||||||
self.state = {
|
|
||||||
k.replace("generator.model", "model.model"): v
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
self.supports_fp16 = False
|
|
||||||
self.support_bf16 = True
|
|
||||||
|
|
||||||
self.load_state_dict(self.state, strict=False)
|
|
||||||
|
|
||||||
def forward(self, img, mask):
|
|
||||||
masked_img = img * (1 - mask)
|
|
||||||
inpainted_mask = mask * self.model.forward(masked_img, mask)
|
|
||||||
result = inpainted_mask + (1 - mask) * img
|
|
||||||
return result
|
|
||||||
@ -1,110 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class CA_layer(nn.Module):
|
|
||||||
def __init__(self, channel, reduction=16):
|
|
||||||
super(CA_layer, self).__init__()
|
|
||||||
# global average pooling
|
|
||||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
|
|
||||||
# nn.Sigmoid()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
y = self.fc(self.gap(x))
|
|
||||||
return x * y.expand_as(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Simple_CA_layer(nn.Module):
|
|
||||||
def __init__(self, channel):
|
|
||||||
super(Simple_CA_layer, self).__init__()
|
|
||||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.fc = nn.Conv2d(
|
|
||||||
in_channels=channel,
|
|
||||||
out_channels=channel,
|
|
||||||
kernel_size=1,
|
|
||||||
padding=0,
|
|
||||||
stride=1,
|
|
||||||
groups=1,
|
|
||||||
bias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x * self.fc(self.gap(x))
|
|
||||||
|
|
||||||
|
|
||||||
class ECA_layer(nn.Module):
|
|
||||||
"""Constructs a ECA module.
|
|
||||||
Args:
|
|
||||||
channel: Number of channels of the input feature map
|
|
||||||
k_size: Adaptive selection of kernel size
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, channel):
|
|
||||||
super(ECA_layer, self).__init__()
|
|
||||||
|
|
||||||
b = 1
|
|
||||||
gamma = 2
|
|
||||||
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
|
||||||
k_size = k_size if k_size % 2 else k_size + 1
|
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.conv = nn.Conv1d(
|
|
||||||
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
|
||||||
)
|
|
||||||
# self.sigmoid = nn.Sigmoid()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# x: input features with shape [b, c, h, w]
|
|
||||||
# b, c, h, w = x.size()
|
|
||||||
|
|
||||||
# feature descriptor on the global spatial information
|
|
||||||
y = self.avg_pool(x)
|
|
||||||
|
|
||||||
# Two different branches of ECA module
|
|
||||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
|
||||||
|
|
||||||
# Multi-scale information fusion
|
|
||||||
# y = self.sigmoid(y)
|
|
||||||
|
|
||||||
return x * y.expand_as(x)
|
|
||||||
|
|
||||||
|
|
||||||
class ECA_MaxPool_layer(nn.Module):
|
|
||||||
"""Constructs a ECA module.
|
|
||||||
Args:
|
|
||||||
channel: Number of channels of the input feature map
|
|
||||||
k_size: Adaptive selection of kernel size
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, channel):
|
|
||||||
super(ECA_MaxPool_layer, self).__init__()
|
|
||||||
|
|
||||||
b = 1
|
|
||||||
gamma = 2
|
|
||||||
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
|
||||||
k_size = k_size if k_size % 2 else k_size + 1
|
|
||||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
|
||||||
self.conv = nn.Conv1d(
|
|
||||||
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
|
||||||
)
|
|
||||||
# self.sigmoid = nn.Sigmoid()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# x: input features with shape [b, c, h, w]
|
|
||||||
# b, c, h, w = x.size()
|
|
||||||
|
|
||||||
# feature descriptor on the global spatial information
|
|
||||||
y = self.max_pool(x)
|
|
||||||
|
|
||||||
# Two different branches of ECA module
|
|
||||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
|
||||||
|
|
||||||
# Multi-scale information fusion
|
|
||||||
# y = self.sigmoid(y)
|
|
||||||
|
|
||||||
return x * y.expand_as(x)
|
|
||||||
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@ -1,577 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
#############################################################
|
|
||||||
# File: OSA.py
|
|
||||||
# Created Date: Tuesday April 28th 2022
|
|
||||||
# Author: Chen Xuanhong
|
|
||||||
# Email: chenxuanhongzju@outlook.com
|
|
||||||
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
|
|
||||||
# Modified By: Chen Xuanhong
|
|
||||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
|
||||||
#############################################################
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from einops.layers.torch import Rearrange, Reduce
|
|
||||||
from torch import einsum, nn
|
|
||||||
|
|
||||||
from .layernorm import LayerNorm2d
|
|
||||||
|
|
||||||
# helpers
|
|
||||||
|
|
||||||
|
|
||||||
def exists(val):
|
|
||||||
return val is not None
|
|
||||||
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
return val if exists(val) else d
|
|
||||||
|
|
||||||
|
|
||||||
def cast_tuple(val, length=1):
|
|
||||||
return val if isinstance(val, tuple) else ((val,) * length)
|
|
||||||
|
|
||||||
|
|
||||||
# helper classes
|
|
||||||
|
|
||||||
|
|
||||||
class PreNormResidual(nn.Module):
|
|
||||||
def __init__(self, dim, fn):
|
|
||||||
super().__init__()
|
|
||||||
self.norm = nn.LayerNorm(dim)
|
|
||||||
self.fn = fn
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.fn(self.norm(x)) + x
|
|
||||||
|
|
||||||
|
|
||||||
class Conv_PreNormResidual(nn.Module):
|
|
||||||
def __init__(self, dim, fn):
|
|
||||||
super().__init__()
|
|
||||||
self.norm = LayerNorm2d(dim)
|
|
||||||
self.fn = fn
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.fn(self.norm(x)) + x
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
|
||||||
def __init__(self, dim, mult=2, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = int(dim * mult)
|
|
||||||
self.net = nn.Sequential(
|
|
||||||
nn.Linear(dim, inner_dim),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Dropout(dropout),
|
|
||||||
nn.Linear(inner_dim, dim),
|
|
||||||
nn.Dropout(dropout),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.net(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Conv_FeedForward(nn.Module):
|
|
||||||
def __init__(self, dim, mult=2, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = int(dim * mult)
|
|
||||||
self.net = nn.Sequential(
|
|
||||||
nn.Conv2d(dim, inner_dim, 1, 1, 0),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Dropout(dropout),
|
|
||||||
nn.Conv2d(inner_dim, dim, 1, 1, 0),
|
|
||||||
nn.Dropout(dropout),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.net(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Gated_Conv_FeedForward(nn.Module):
|
|
||||||
def __init__(self, dim, mult=1, bias=False, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
hidden_features = int(dim * mult)
|
|
||||||
|
|
||||||
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
|
|
||||||
|
|
||||||
self.dwconv = nn.Conv2d(
|
|
||||||
hidden_features * 2,
|
|
||||||
hidden_features * 2,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
groups=hidden_features * 2,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.project_in(x)
|
|
||||||
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
|
||||||
x = F.gelu(x1) * x2
|
|
||||||
x = self.project_out(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
# MBConv
|
|
||||||
|
|
||||||
|
|
||||||
class SqueezeExcitation(nn.Module):
|
|
||||||
def __init__(self, dim, shrinkage_rate=0.25):
|
|
||||||
super().__init__()
|
|
||||||
hidden_dim = int(dim * shrinkage_rate)
|
|
||||||
|
|
||||||
self.gate = nn.Sequential(
|
|
||||||
Reduce("b c h w -> b c", "mean"),
|
|
||||||
nn.Linear(dim, hidden_dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(hidden_dim, dim, bias=False),
|
|
||||||
nn.Sigmoid(),
|
|
||||||
Rearrange("b c -> b c 1 1"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x * self.gate(x)
|
|
||||||
|
|
||||||
|
|
||||||
class MBConvResidual(nn.Module):
|
|
||||||
def __init__(self, fn, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.fn = fn
|
|
||||||
self.dropsample = Dropsample(dropout)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.fn(x)
|
|
||||||
out = self.dropsample(out)
|
|
||||||
return out + x
|
|
||||||
|
|
||||||
|
|
||||||
class Dropsample(nn.Module):
|
|
||||||
def __init__(self, prob=0):
|
|
||||||
super().__init__()
|
|
||||||
self.prob = prob
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
device = x.device
|
|
||||||
|
|
||||||
if self.prob == 0.0 or (not self.training):
|
|
||||||
return x
|
|
||||||
|
|
||||||
keep_mask = (
|
|
||||||
torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()
|
|
||||||
> self.prob
|
|
||||||
)
|
|
||||||
return x * keep_mask / (1 - self.prob)
|
|
||||||
|
|
||||||
|
|
||||||
def MBConv(
|
|
||||||
dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
|
|
||||||
):
|
|
||||||
hidden_dim = int(expansion_rate * dim_out)
|
|
||||||
stride = 2 if downsample else 1
|
|
||||||
|
|
||||||
net = nn.Sequential(
|
|
||||||
nn.Conv2d(dim_in, hidden_dim, 1),
|
|
||||||
# nn.BatchNorm2d(hidden_dim),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Conv2d(
|
|
||||||
hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
|
|
||||||
),
|
|
||||||
# nn.BatchNorm2d(hidden_dim),
|
|
||||||
nn.GELU(),
|
|
||||||
SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
|
|
||||||
nn.Conv2d(hidden_dim, dim_out, 1),
|
|
||||||
# nn.BatchNorm2d(dim_out)
|
|
||||||
)
|
|
||||||
|
|
||||||
if dim_in == dim_out and not downsample:
|
|
||||||
net = MBConvResidual(net, dropout=dropout)
|
|
||||||
|
|
||||||
return net
|
|
||||||
|
|
||||||
|
|
||||||
# attention related classes
|
|
||||||
class Attention(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
dim_head=32,
|
|
||||||
dropout=0.0,
|
|
||||||
window_size=7,
|
|
||||||
with_pe=True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
assert (
|
|
||||||
dim % dim_head
|
|
||||||
) == 0, "dimension should be divisible by dimension per head"
|
|
||||||
|
|
||||||
self.heads = dim // dim_head
|
|
||||||
self.scale = dim_head**-0.5
|
|
||||||
self.with_pe = with_pe
|
|
||||||
|
|
||||||
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
|
||||||
|
|
||||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
|
||||||
nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
|
|
||||||
)
|
|
||||||
|
|
||||||
# relative positional bias
|
|
||||||
if self.with_pe:
|
|
||||||
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
|
||||||
|
|
||||||
pos = torch.arange(window_size)
|
|
||||||
grid = torch.stack(torch.meshgrid(pos, pos))
|
|
||||||
grid = rearrange(grid, "c i j -> (i j) c")
|
|
||||||
rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
|
|
||||||
grid, "j ... -> 1 j ..."
|
|
||||||
)
|
|
||||||
rel_pos += window_size - 1
|
|
||||||
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
|
|
||||||
dim=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
batch, height, width, window_height, window_width, _, device, h = (
|
|
||||||
*x.shape,
|
|
||||||
x.device,
|
|
||||||
self.heads,
|
|
||||||
)
|
|
||||||
|
|
||||||
# flatten
|
|
||||||
|
|
||||||
x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
|
|
||||||
|
|
||||||
# project for queries, keys, values
|
|
||||||
|
|
||||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
|
||||||
|
|
||||||
# split heads
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))
|
|
||||||
|
|
||||||
# scale
|
|
||||||
|
|
||||||
q = q * self.scale
|
|
||||||
|
|
||||||
# sim
|
|
||||||
|
|
||||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
|
||||||
|
|
||||||
# add positional bias
|
|
||||||
if self.with_pe:
|
|
||||||
bias = self.rel_pos_bias(self.rel_pos_indices)
|
|
||||||
sim = sim + rearrange(bias, "i j h -> h i j")
|
|
||||||
|
|
||||||
# attention
|
|
||||||
|
|
||||||
attn = self.attend(sim)
|
|
||||||
|
|
||||||
# aggregate
|
|
||||||
|
|
||||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
|
||||||
|
|
||||||
# merge heads
|
|
||||||
|
|
||||||
out = rearrange(
|
|
||||||
out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
|
|
||||||
)
|
|
||||||
|
|
||||||
# combine heads out
|
|
||||||
|
|
||||||
out = self.to_out(out)
|
|
||||||
return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
|
|
||||||
|
|
||||||
|
|
||||||
class Block_Attention(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
dim_head=32,
|
|
||||||
bias=False,
|
|
||||||
dropout=0.0,
|
|
||||||
window_size=7,
|
|
||||||
with_pe=True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
assert (
|
|
||||||
dim % dim_head
|
|
||||||
) == 0, "dimension should be divisible by dimension per head"
|
|
||||||
|
|
||||||
self.heads = dim // dim_head
|
|
||||||
self.ps = window_size
|
|
||||||
self.scale = dim_head**-0.5
|
|
||||||
self.with_pe = with_pe
|
|
||||||
|
|
||||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
|
||||||
self.qkv_dwconv = nn.Conv2d(
|
|
||||||
dim * 3,
|
|
||||||
dim * 3,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
groups=dim * 3,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
|
||||||
|
|
||||||
self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# project for queries, keys, values
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
|
|
||||||
qkv = self.qkv_dwconv(self.qkv(x))
|
|
||||||
q, k, v = qkv.chunk(3, dim=1)
|
|
||||||
|
|
||||||
# split heads
|
|
||||||
|
|
||||||
q, k, v = map(
|
|
||||||
lambda t: rearrange(
|
|
||||||
t,
|
|
||||||
"b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
|
|
||||||
h=self.heads,
|
|
||||||
w1=self.ps,
|
|
||||||
w2=self.ps,
|
|
||||||
),
|
|
||||||
(q, k, v),
|
|
||||||
)
|
|
||||||
|
|
||||||
# scale
|
|
||||||
|
|
||||||
q = q * self.scale
|
|
||||||
|
|
||||||
# sim
|
|
||||||
|
|
||||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
|
||||||
|
|
||||||
# attention
|
|
||||||
attn = self.attend(sim)
|
|
||||||
|
|
||||||
# aggregate
|
|
||||||
|
|
||||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
|
||||||
|
|
||||||
# merge heads
|
|
||||||
out = rearrange(
|
|
||||||
out,
|
|
||||||
"(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
|
|
||||||
x=h // self.ps,
|
|
||||||
y=w // self.ps,
|
|
||||||
head=self.heads,
|
|
||||||
w1=self.ps,
|
|
||||||
w2=self.ps,
|
|
||||||
)
|
|
||||||
|
|
||||||
out = self.to_out(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Channel_Attention(nn.Module):
|
|
||||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
|
||||||
super(Channel_Attention, self).__init__()
|
|
||||||
self.heads = heads
|
|
||||||
|
|
||||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
|
||||||
|
|
||||||
self.ps = window_size
|
|
||||||
|
|
||||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
|
||||||
self.qkv_dwconv = nn.Conv2d(
|
|
||||||
dim * 3,
|
|
||||||
dim * 3,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
groups=dim * 3,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
|
|
||||||
qkv = self.qkv_dwconv(self.qkv(x))
|
|
||||||
qkv = qkv.chunk(3, dim=1)
|
|
||||||
|
|
||||||
q, k, v = map(
|
|
||||||
lambda t: rearrange(
|
|
||||||
t,
|
|
||||||
"b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
|
|
||||||
ph=self.ps,
|
|
||||||
pw=self.ps,
|
|
||||||
head=self.heads,
|
|
||||||
),
|
|
||||||
qkv,
|
|
||||||
)
|
|
||||||
|
|
||||||
q = F.normalize(q, dim=-1)
|
|
||||||
k = F.normalize(k, dim=-1)
|
|
||||||
|
|
||||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
|
||||||
attn = attn.softmax(dim=-1)
|
|
||||||
out = attn @ v
|
|
||||||
|
|
||||||
out = rearrange(
|
|
||||||
out,
|
|
||||||
"b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
|
|
||||||
h=h // self.ps,
|
|
||||||
w=w // self.ps,
|
|
||||||
ph=self.ps,
|
|
||||||
pw=self.ps,
|
|
||||||
head=self.heads,
|
|
||||||
)
|
|
||||||
|
|
||||||
out = self.project_out(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Channel_Attention_grid(nn.Module):
|
|
||||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
|
||||||
super(Channel_Attention_grid, self).__init__()
|
|
||||||
self.heads = heads
|
|
||||||
|
|
||||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
|
||||||
|
|
||||||
self.ps = window_size
|
|
||||||
|
|
||||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
|
||||||
self.qkv_dwconv = nn.Conv2d(
|
|
||||||
dim * 3,
|
|
||||||
dim * 3,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
groups=dim * 3,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
|
|
||||||
qkv = self.qkv_dwconv(self.qkv(x))
|
|
||||||
qkv = qkv.chunk(3, dim=1)
|
|
||||||
|
|
||||||
q, k, v = map(
|
|
||||||
lambda t: rearrange(
|
|
||||||
t,
|
|
||||||
"b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
|
|
||||||
ph=self.ps,
|
|
||||||
pw=self.ps,
|
|
||||||
head=self.heads,
|
|
||||||
),
|
|
||||||
qkv,
|
|
||||||
)
|
|
||||||
|
|
||||||
q = F.normalize(q, dim=-1)
|
|
||||||
k = F.normalize(k, dim=-1)
|
|
||||||
|
|
||||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
|
||||||
attn = attn.softmax(dim=-1)
|
|
||||||
out = attn @ v
|
|
||||||
|
|
||||||
out = rearrange(
|
|
||||||
out,
|
|
||||||
"b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
|
|
||||||
h=h // self.ps,
|
|
||||||
w=w // self.ps,
|
|
||||||
ph=self.ps,
|
|
||||||
pw=self.ps,
|
|
||||||
head=self.heads,
|
|
||||||
)
|
|
||||||
|
|
||||||
out = self.project_out(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class OSA_Block(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channel_num=64,
|
|
||||||
bias=True,
|
|
||||||
ffn_bias=True,
|
|
||||||
window_size=8,
|
|
||||||
with_pe=False,
|
|
||||||
dropout=0.0,
|
|
||||||
):
|
|
||||||
super(OSA_Block, self).__init__()
|
|
||||||
|
|
||||||
w = window_size
|
|
||||||
|
|
||||||
self.layer = nn.Sequential(
|
|
||||||
MBConv(
|
|
||||||
channel_num,
|
|
||||||
channel_num,
|
|
||||||
downsample=False,
|
|
||||||
expansion_rate=1,
|
|
||||||
shrinkage_rate=0.25,
|
|
||||||
),
|
|
||||||
Rearrange(
|
|
||||||
"b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
|
|
||||||
), # block-like attention
|
|
||||||
PreNormResidual(
|
|
||||||
channel_num,
|
|
||||||
Attention(
|
|
||||||
dim=channel_num,
|
|
||||||
dim_head=channel_num // 4,
|
|
||||||
dropout=dropout,
|
|
||||||
window_size=window_size,
|
|
||||||
with_pe=with_pe,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
|
|
||||||
Conv_PreNormResidual(
|
|
||||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
|
||||||
),
|
|
||||||
# channel-like attention
|
|
||||||
Conv_PreNormResidual(
|
|
||||||
channel_num,
|
|
||||||
Channel_Attention(
|
|
||||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Conv_PreNormResidual(
|
|
||||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
|
||||||
),
|
|
||||||
Rearrange(
|
|
||||||
"b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
|
|
||||||
), # grid-like attention
|
|
||||||
PreNormResidual(
|
|
||||||
channel_num,
|
|
||||||
Attention(
|
|
||||||
dim=channel_num,
|
|
||||||
dim_head=channel_num // 4,
|
|
||||||
dropout=dropout,
|
|
||||||
window_size=window_size,
|
|
||||||
with_pe=with_pe,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
|
|
||||||
Conv_PreNormResidual(
|
|
||||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
|
||||||
),
|
|
||||||
# channel-like attention
|
|
||||||
Conv_PreNormResidual(
|
|
||||||
channel_num,
|
|
||||||
Channel_Attention_grid(
|
|
||||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Conv_PreNormResidual(
|
|
||||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.layer(x)
|
|
||||||
return out
|
|
||||||
@ -1,60 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
#############################################################
|
|
||||||
# File: OSAG.py
|
|
||||||
# Created Date: Tuesday April 28th 2022
|
|
||||||
# Author: Chen Xuanhong
|
|
||||||
# Email: chenxuanhongzju@outlook.com
|
|
||||||
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
|
|
||||||
# Modified By: Chen Xuanhong
|
|
||||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
|
||||||
#############################################################
|
|
||||||
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from .esa import ESA
|
|
||||||
from .OSA import OSA_Block
|
|
||||||
|
|
||||||
|
|
||||||
class OSAG(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channel_num=64,
|
|
||||||
bias=True,
|
|
||||||
block_num=4,
|
|
||||||
ffn_bias=False,
|
|
||||||
window_size=0,
|
|
||||||
pe=False,
|
|
||||||
):
|
|
||||||
super(OSAG, self).__init__()
|
|
||||||
|
|
||||||
# print("window_size: %d" % (window_size))
|
|
||||||
# print("with_pe", pe)
|
|
||||||
# print("ffn_bias: %d" % (ffn_bias))
|
|
||||||
|
|
||||||
# block_script_name = kwargs.get("block_script_name", "OSA")
|
|
||||||
# block_class_name = kwargs.get("block_class_name", "OSA_Block")
|
|
||||||
|
|
||||||
# script_name = "." + block_script_name
|
|
||||||
# package = __import__(script_name, fromlist=True)
|
|
||||||
block_class = OSA_Block # getattr(package, block_class_name)
|
|
||||||
group_list = []
|
|
||||||
for _ in range(block_num):
|
|
||||||
temp_res = block_class(
|
|
||||||
channel_num,
|
|
||||||
bias,
|
|
||||||
ffn_bias=ffn_bias,
|
|
||||||
window_size=window_size,
|
|
||||||
with_pe=pe,
|
|
||||||
)
|
|
||||||
group_list.append(temp_res)
|
|
||||||
group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
|
|
||||||
self.residual_layer = nn.Sequential(*group_list)
|
|
||||||
esa_channel = max(channel_num // 4, 16)
|
|
||||||
self.esa = ESA(esa_channel, channel_num)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.residual_layer(x)
|
|
||||||
out = out + x
|
|
||||||
return self.esa(out)
|
|
||||||
@ -1,143 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
#############################################################
|
|
||||||
# File: OmniSR.py
|
|
||||||
# Created Date: Tuesday April 28th 2022
|
|
||||||
# Author: Chen Xuanhong
|
|
||||||
# Email: chenxuanhongzju@outlook.com
|
|
||||||
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
|
|
||||||
# Modified By: Chen Xuanhong
|
|
||||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
|
||||||
#############################################################
|
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .OSAG import OSAG
|
|
||||||
from .pixelshuffle import pixelshuffle_block
|
|
||||||
|
|
||||||
|
|
||||||
class OmniSR(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dict,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super(OmniSR, self).__init__()
|
|
||||||
self.state = state_dict
|
|
||||||
|
|
||||||
bias = True # Fine to assume this for now
|
|
||||||
block_num = 1 # Fine to assume this for now
|
|
||||||
ffn_bias = True
|
|
||||||
pe = True
|
|
||||||
|
|
||||||
num_feat = state_dict["input.weight"].shape[0] or 64
|
|
||||||
num_in_ch = state_dict["input.weight"].shape[1] or 3
|
|
||||||
num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh
|
|
||||||
|
|
||||||
pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
|
|
||||||
up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
|
|
||||||
if up_scale - int(up_scale) > 0:
|
|
||||||
print(
|
|
||||||
"out_nc is probably different than in_nc, scale calculation might be wrong"
|
|
||||||
)
|
|
||||||
up_scale = int(up_scale)
|
|
||||||
res_num = 0
|
|
||||||
for key in state_dict.keys():
|
|
||||||
if "residual_layer" in key:
|
|
||||||
temp_res_num = int(key.split(".")[1])
|
|
||||||
if temp_res_num > res_num:
|
|
||||||
res_num = temp_res_num
|
|
||||||
res_num = res_num + 1 # zero-indexed
|
|
||||||
|
|
||||||
residual_layer = []
|
|
||||||
self.res_num = res_num
|
|
||||||
|
|
||||||
if (
|
|
||||||
"residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
|
|
||||||
in state_dict.keys()
|
|
||||||
):
|
|
||||||
rel_pos_bias_weight = state_dict[
|
|
||||||
"residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
|
|
||||||
].shape[0]
|
|
||||||
self.window_size = int((math.sqrt(rel_pos_bias_weight) + 1) / 2)
|
|
||||||
else:
|
|
||||||
self.window_size = 8
|
|
||||||
|
|
||||||
self.up_scale = up_scale
|
|
||||||
|
|
||||||
for _ in range(res_num):
|
|
||||||
temp_res = OSAG(
|
|
||||||
channel_num=num_feat,
|
|
||||||
bias=bias,
|
|
||||||
block_num=block_num,
|
|
||||||
ffn_bias=ffn_bias,
|
|
||||||
window_size=self.window_size,
|
|
||||||
pe=pe,
|
|
||||||
)
|
|
||||||
residual_layer.append(temp_res)
|
|
||||||
self.residual_layer = nn.Sequential(*residual_layer)
|
|
||||||
self.input = nn.Conv2d(
|
|
||||||
in_channels=num_in_ch,
|
|
||||||
out_channels=num_feat,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.output = nn.Conv2d(
|
|
||||||
in_channels=num_feat,
|
|
||||||
out_channels=num_feat,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
|
|
||||||
|
|
||||||
# self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
|
|
||||||
|
|
||||||
# for m in self.modules():
|
|
||||||
# if isinstance(m, nn.Conv2d):
|
|
||||||
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
||||||
# m.weight.data.normal_(0, sqrt(2. / n))
|
|
||||||
|
|
||||||
# chaiNNer specific stuff
|
|
||||||
self.model_arch = "OmniSR"
|
|
||||||
self.sub_type = "SR"
|
|
||||||
self.in_nc = num_in_ch
|
|
||||||
self.out_nc = num_out_ch
|
|
||||||
self.num_feat = num_feat
|
|
||||||
self.scale = up_scale
|
|
||||||
|
|
||||||
self.supports_fp16 = True # TODO: Test this
|
|
||||||
self.supports_bfp16 = True
|
|
||||||
self.min_size_restriction = 16
|
|
||||||
|
|
||||||
self.load_state_dict(state_dict, strict=False)
|
|
||||||
|
|
||||||
def check_image_size(self, x):
|
|
||||||
_, _, h, w = x.size()
|
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
|
||||||
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
|
||||||
# x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
|
||||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
H, W = x.shape[2:]
|
|
||||||
x = self.check_image_size(x)
|
|
||||||
|
|
||||||
residual = self.input(x)
|
|
||||||
out = self.residual_layer(residual)
|
|
||||||
|
|
||||||
# origin
|
|
||||||
out = torch.add(self.output(out), residual)
|
|
||||||
out = self.up(out)
|
|
||||||
|
|
||||||
out = out[:, :, : H * self.up_scale, : W * self.up_scale]
|
|
||||||
return out
|
|
||||||
@ -1,294 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
#############################################################
|
|
||||||
# File: esa.py
|
|
||||||
# Created Date: Tuesday April 28th 2022
|
|
||||||
# Author: Chen Xuanhong
|
|
||||||
# Email: chenxuanhongzju@outlook.com
|
|
||||||
# Last Modified: Thursday, 20th April 2023 9:28:06 am
|
|
||||||
# Modified By: Chen Xuanhong
|
|
||||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
|
||||||
#############################################################
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .layernorm import LayerNorm2d
|
|
||||||
|
|
||||||
|
|
||||||
def moment(x, dim=(2, 3), k=2):
|
|
||||||
assert len(x.size()) == 4
|
|
||||||
mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
|
|
||||||
mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
|
|
||||||
return mk
|
|
||||||
|
|
||||||
|
|
||||||
class ESA(nn.Module):
|
|
||||||
"""
|
|
||||||
Modification of Enhanced Spatial Attention (ESA), which is proposed by
|
|
||||||
`Residual Feature Aggregation Network for Image Super-Resolution`
|
|
||||||
Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes
|
|
||||||
are deleted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
|
|
||||||
super(ESA, self).__init__()
|
|
||||||
f = esa_channels
|
|
||||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
|
||||||
self.conv_f = conv(f, f, kernel_size=1)
|
|
||||||
self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
|
|
||||||
self.conv3 = conv(f, f, kernel_size=3, padding=1)
|
|
||||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
|
||||||
self.sigmoid = nn.Sigmoid()
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
c1_ = self.conv1(x)
|
|
||||||
c1 = self.conv2(c1_)
|
|
||||||
v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
|
|
||||||
c3 = self.conv3(v_max)
|
|
||||||
c3 = F.interpolate(
|
|
||||||
c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
cf = self.conv_f(c1_)
|
|
||||||
c4 = self.conv4(c3 + cf)
|
|
||||||
m = self.sigmoid(c4)
|
|
||||||
return x * m
|
|
||||||
|
|
||||||
|
|
||||||
class LK_ESA(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
|
||||||
):
|
|
||||||
super(LK_ESA, self).__init__()
|
|
||||||
f = esa_channels
|
|
||||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
|
||||||
self.conv_f = conv(f, f, kernel_size=1)
|
|
||||||
|
|
||||||
kernel_size = 17
|
|
||||||
kernel_expand = kernel_expand
|
|
||||||
padding = kernel_size // 2
|
|
||||||
|
|
||||||
self.vec_conv = nn.Conv2d(
|
|
||||||
in_channels=f * kernel_expand,
|
|
||||||
out_channels=f * kernel_expand,
|
|
||||||
kernel_size=(1, kernel_size),
|
|
||||||
padding=(0, padding),
|
|
||||||
groups=2,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.vec_conv3x1 = nn.Conv2d(
|
|
||||||
in_channels=f * kernel_expand,
|
|
||||||
out_channels=f * kernel_expand,
|
|
||||||
kernel_size=(1, 3),
|
|
||||||
padding=(0, 1),
|
|
||||||
groups=2,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.hor_conv = nn.Conv2d(
|
|
||||||
in_channels=f * kernel_expand,
|
|
||||||
out_channels=f * kernel_expand,
|
|
||||||
kernel_size=(kernel_size, 1),
|
|
||||||
padding=(padding, 0),
|
|
||||||
groups=2,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.hor_conv1x3 = nn.Conv2d(
|
|
||||||
in_channels=f * kernel_expand,
|
|
||||||
out_channels=f * kernel_expand,
|
|
||||||
kernel_size=(3, 1),
|
|
||||||
padding=(1, 0),
|
|
||||||
groups=2,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
|
||||||
self.sigmoid = nn.Sigmoid()
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
c1_ = self.conv1(x)
|
|
||||||
|
|
||||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
|
||||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
|
||||||
|
|
||||||
cf = self.conv_f(c1_)
|
|
||||||
c4 = self.conv4(res + cf)
|
|
||||||
m = self.sigmoid(c4)
|
|
||||||
return x * m
|
|
||||||
|
|
||||||
|
|
||||||
class LK_ESA_LN(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
|
||||||
):
|
|
||||||
super(LK_ESA_LN, self).__init__()
|
|
||||||
f = esa_channels
|
|
||||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
|
||||||
self.conv_f = conv(f, f, kernel_size=1)
|
|
||||||
|
|
||||||
kernel_size = 17
|
|
||||||
kernel_expand = kernel_expand
|
|
||||||
padding = kernel_size // 2
|
|
||||||
|
|
||||||
self.norm = LayerNorm2d(n_feats)
|
|
||||||
|
|
||||||
self.vec_conv = nn.Conv2d(
|
|
||||||
in_channels=f * kernel_expand,
|
|
||||||
out_channels=f * kernel_expand,
|
|
||||||
kernel_size=(1, kernel_size),
|
|
||||||
padding=(0, padding),
|
|
||||||
groups=2,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.vec_conv3x1 = nn.Conv2d(
|
|
||||||
in_channels=f * kernel_expand,
|
|
||||||
out_channels=f * kernel_expand,
|
|
||||||
kernel_size=(1, 3),
|
|
||||||
padding=(0, 1),
|
|
||||||
groups=2,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.hor_conv = nn.Conv2d(
|
|
||||||
in_channels=f * kernel_expand,
|
|
||||||
out_channels=f * kernel_expand,
|
|
||||||
kernel_size=(kernel_size, 1),
|
|
||||||
padding=(padding, 0),
|
|
||||||
groups=2,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
self.hor_conv1x3 = nn.Conv2d(
|
|
||||||
in_channels=f * kernel_expand,
|
|
||||||
out_channels=f * kernel_expand,
|
|
||||||
kernel_size=(3, 1),
|
|
||||||
padding=(1, 0),
|
|
||||||
groups=2,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
|
||||||
self.sigmoid = nn.Sigmoid()
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
c1_ = self.norm(x)
|
|
||||||
c1_ = self.conv1(c1_)
|
|
||||||
|
|
||||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
|
||||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
|
||||||
|
|
||||||
cf = self.conv_f(c1_)
|
|
||||||
c4 = self.conv4(res + cf)
|
|
||||||
m = self.sigmoid(c4)
|
|
||||||
return x * m
|
|
||||||
|
|
||||||
|
|
||||||
class AdaGuidedFilter(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
|
||||||
):
|
|
||||||
super(AdaGuidedFilter, self).__init__()
|
|
||||||
|
|
||||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.fc = nn.Conv2d(
|
|
||||||
in_channels=n_feats,
|
|
||||||
out_channels=1,
|
|
||||||
kernel_size=1,
|
|
||||||
padding=0,
|
|
||||||
stride=1,
|
|
||||||
groups=1,
|
|
||||||
bias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.r = 5
|
|
||||||
|
|
||||||
def box_filter(self, x, r):
|
|
||||||
channel = x.shape[1]
|
|
||||||
kernel_size = 2 * r + 1
|
|
||||||
weight = 1.0 / (kernel_size**2)
|
|
||||||
box_kernel = weight * torch.ones(
|
|
||||||
(channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
|
|
||||||
)
|
|
||||||
output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
_, _, H, W = x.shape
|
|
||||||
N = self.box_filter(
|
|
||||||
torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
|
|
||||||
)
|
|
||||||
|
|
||||||
# epsilon = self.fc(self.gap(x))
|
|
||||||
# epsilon = torch.pow(epsilon, 2)
|
|
||||||
epsilon = 1e-2
|
|
||||||
|
|
||||||
mean_x = self.box_filter(x, self.r) / N
|
|
||||||
var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
|
|
||||||
|
|
||||||
A = var_x / (var_x + epsilon)
|
|
||||||
b = (1 - A) * mean_x
|
|
||||||
m = A * x + b
|
|
||||||
|
|
||||||
# mean_A = self.box_filter(A, self.r) / N
|
|
||||||
# mean_b = self.box_filter(b, self.r) / N
|
|
||||||
# m = mean_A * x + mean_b
|
|
||||||
return x * m
|
|
||||||
|
|
||||||
|
|
||||||
class AdaConvGuidedFilter(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
|
||||||
):
|
|
||||||
super(AdaConvGuidedFilter, self).__init__()
|
|
||||||
f = esa_channels
|
|
||||||
|
|
||||||
self.conv_f = conv(f, f, kernel_size=1)
|
|
||||||
|
|
||||||
kernel_size = 17
|
|
||||||
kernel_expand = kernel_expand
|
|
||||||
padding = kernel_size // 2
|
|
||||||
|
|
||||||
self.vec_conv = nn.Conv2d(
|
|
||||||
in_channels=f,
|
|
||||||
out_channels=f,
|
|
||||||
kernel_size=(1, kernel_size),
|
|
||||||
padding=(0, padding),
|
|
||||||
groups=f,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.hor_conv = nn.Conv2d(
|
|
||||||
in_channels=f,
|
|
||||||
out_channels=f,
|
|
||||||
kernel_size=(kernel_size, 1),
|
|
||||||
padding=(padding, 0),
|
|
||||||
groups=f,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.fc = nn.Conv2d(
|
|
||||||
in_channels=f,
|
|
||||||
out_channels=f,
|
|
||||||
kernel_size=1,
|
|
||||||
padding=0,
|
|
||||||
stride=1,
|
|
||||||
groups=1,
|
|
||||||
bias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
y = self.vec_conv(x)
|
|
||||||
y = self.hor_conv(y)
|
|
||||||
|
|
||||||
sigma = torch.pow(y, 2)
|
|
||||||
epsilon = self.fc(self.gap(y))
|
|
||||||
|
|
||||||
weight = sigma / (sigma + epsilon)
|
|
||||||
|
|
||||||
m = weight * x + (1 - weight)
|
|
||||||
|
|
||||||
return x * m
|
|
||||||
@ -1,70 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
#############################################################
|
|
||||||
# File: layernorm.py
|
|
||||||
# Created Date: Tuesday April 28th 2022
|
|
||||||
# Author: Chen Xuanhong
|
|
||||||
# Email: chenxuanhongzju@outlook.com
|
|
||||||
# Last Modified: Thursday, 20th April 2023 9:28:20 am
|
|
||||||
# Modified By: Chen Xuanhong
|
|
||||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
|
||||||
#############################################################
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNormFunction(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, x, weight, bias, eps):
|
|
||||||
ctx.eps = eps
|
|
||||||
N, C, H, W = x.size()
|
|
||||||
mu = x.mean(1, keepdim=True)
|
|
||||||
var = (x - mu).pow(2).mean(1, keepdim=True)
|
|
||||||
y = (x - mu) / (var + eps).sqrt()
|
|
||||||
ctx.save_for_backward(y, var, weight)
|
|
||||||
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
|
|
||||||
return y
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
eps = ctx.eps
|
|
||||||
|
|
||||||
N, C, H, W = grad_output.size()
|
|
||||||
y, var, weight = ctx.saved_variables
|
|
||||||
g = grad_output * weight.view(1, C, 1, 1)
|
|
||||||
mean_g = g.mean(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
mean_gy = (g * y).mean(dim=1, keepdim=True)
|
|
||||||
gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
|
|
||||||
return (
|
|
||||||
gx,
|
|
||||||
(grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
|
|
||||||
grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm2d(nn.Module):
|
|
||||||
def __init__(self, channels, eps=1e-6):
|
|
||||||
super(LayerNorm2d, self).__init__()
|
|
||||||
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
|
|
||||||
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class GRN(nn.Module):
|
|
||||||
"""GRN (Global Response Normalization) layer"""
|
|
||||||
|
|
||||||
def __init__(self, dim):
|
|
||||||
super().__init__()
|
|
||||||
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
|
||||||
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
|
|
||||||
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
|
|
||||||
return self.gamma * (x * Nx) + self.beta + x
|
|
||||||
@ -1,31 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
#############################################################
|
|
||||||
# File: pixelshuffle.py
|
|
||||||
# Created Date: Friday July 1st 2022
|
|
||||||
# Author: Chen Xuanhong
|
|
||||||
# Email: chenxuanhongzju@outlook.com
|
|
||||||
# Last Modified: Friday, 1st July 2022 10:18:39 am
|
|
||||||
# Modified By: Chen Xuanhong
|
|
||||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
|
||||||
#############################################################
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def pixelshuffle_block(
|
|
||||||
in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Upsample features according to `upscale_factor`.
|
|
||||||
"""
|
|
||||||
padding = kernel_size // 2
|
|
||||||
conv = nn.Conv2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels * (upscale_factor**2),
|
|
||||||
kernel_size,
|
|
||||||
padding=1,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
|
||||||
return nn.Sequential(*[conv, pixel_shuffle])
|
|
||||||
@ -1,296 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import math
|
|
||||||
import re
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from . import block as B
|
|
||||||
|
|
||||||
|
|
||||||
# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
|
|
||||||
# Which enhanced stuff that was already here
|
|
||||||
class RRDBNet(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dict,
|
|
||||||
norm=None,
|
|
||||||
act: str = "leakyrelu",
|
|
||||||
upsampler: str = "upconv",
|
|
||||||
mode: B.ConvMode = "CNA",
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
|
|
||||||
By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
|
|
||||||
and Chen Change Loy.
|
|
||||||
This is old-arch Residual in Residual Dense Block Network and is not
|
|
||||||
the newest revision that's available at github.com/xinntao/ESRGAN.
|
|
||||||
This is on purpose, the newest Network has severely limited the
|
|
||||||
potential use of the Network with no benefits.
|
|
||||||
This network supports model files from both new and old-arch.
|
|
||||||
Args:
|
|
||||||
norm: Normalization layer
|
|
||||||
act: Activation layer
|
|
||||||
upsampler: Upsample layer. upconv, pixel_shuffle
|
|
||||||
mode: Convolution mode
|
|
||||||
"""
|
|
||||||
super(RRDBNet, self).__init__()
|
|
||||||
self.model_arch = "ESRGAN"
|
|
||||||
self.sub_type = "SR"
|
|
||||||
|
|
||||||
self.state = state_dict
|
|
||||||
self.norm = norm
|
|
||||||
self.act = act
|
|
||||||
self.upsampler = upsampler
|
|
||||||
self.mode = mode
|
|
||||||
|
|
||||||
self.state_map = {
|
|
||||||
# currently supports old, new, and newer RRDBNet arch models
|
|
||||||
# ESRGAN, BSRGAN/RealSR, Real-ESRGAN
|
|
||||||
"model.0.weight": ("conv_first.weight",),
|
|
||||||
"model.0.bias": ("conv_first.bias",),
|
|
||||||
"model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
|
|
||||||
"model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
|
|
||||||
r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
|
|
||||||
r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
|
|
||||||
r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
if "params_ema" in self.state:
|
|
||||||
self.state = self.state["params_ema"]
|
|
||||||
# self.model_arch = "RealESRGAN"
|
|
||||||
self.num_blocks = self.get_num_blocks()
|
|
||||||
self.plus = any("conv1x1" in k for k in self.state.keys())
|
|
||||||
if self.plus:
|
|
||||||
self.model_arch = "ESRGAN+"
|
|
||||||
|
|
||||||
self.state = self.new_to_old_arch(self.state)
|
|
||||||
|
|
||||||
self.key_arr = list(self.state.keys())
|
|
||||||
|
|
||||||
self.in_nc: int = self.state[self.key_arr[0]].shape[1]
|
|
||||||
self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
|
|
||||||
|
|
||||||
self.scale: int = self.get_scale()
|
|
||||||
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
|
|
||||||
|
|
||||||
c2x2 = False
|
|
||||||
if self.state["model.0.weight"].shape[-2] == 2:
|
|
||||||
c2x2 = True
|
|
||||||
self.scale = round(math.sqrt(self.scale / 4))
|
|
||||||
self.model_arch = "ESRGAN-2c2"
|
|
||||||
|
|
||||||
self.supports_fp16 = True
|
|
||||||
self.supports_bfp16 = True
|
|
||||||
self.min_size_restriction = None
|
|
||||||
|
|
||||||
# Detect if pixelunshuffle was used (Real-ESRGAN)
|
|
||||||
if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
|
|
||||||
self.in_nc / 4,
|
|
||||||
self.in_nc / 16,
|
|
||||||
):
|
|
||||||
self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
|
|
||||||
else:
|
|
||||||
self.shuffle_factor = None
|
|
||||||
|
|
||||||
upsample_block = {
|
|
||||||
"upconv": B.upconv_block,
|
|
||||||
"pixel_shuffle": B.pixelshuffle_block,
|
|
||||||
}.get(self.upsampler)
|
|
||||||
if upsample_block is None:
|
|
||||||
raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")
|
|
||||||
|
|
||||||
if self.scale == 3:
|
|
||||||
upsample_blocks = upsample_block(
|
|
||||||
in_nc=self.num_filters,
|
|
||||||
out_nc=self.num_filters,
|
|
||||||
upscale_factor=3,
|
|
||||||
act_type=self.act,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
upsample_blocks = [
|
|
||||||
upsample_block(
|
|
||||||
in_nc=self.num_filters,
|
|
||||||
out_nc=self.num_filters,
|
|
||||||
act_type=self.act,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
for _ in range(int(math.log(self.scale, 2)))
|
|
||||||
]
|
|
||||||
|
|
||||||
self.model = B.sequential(
|
|
||||||
# fea conv
|
|
||||||
B.conv_block(
|
|
||||||
in_nc=self.in_nc,
|
|
||||||
out_nc=self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
c2x2=c2x2,
|
|
||||||
),
|
|
||||||
B.ShortcutBlock(
|
|
||||||
B.sequential(
|
|
||||||
# rrdb blocks
|
|
||||||
*[
|
|
||||||
B.RRDB(
|
|
||||||
nf=self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
gc=32,
|
|
||||||
stride=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=self.norm,
|
|
||||||
act_type=self.act,
|
|
||||||
mode="CNA",
|
|
||||||
plus=self.plus,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
for _ in range(self.num_blocks)
|
|
||||||
],
|
|
||||||
# lr conv
|
|
||||||
B.conv_block(
|
|
||||||
in_nc=self.num_filters,
|
|
||||||
out_nc=self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=self.norm,
|
|
||||||
act_type=None,
|
|
||||||
mode=self.mode,
|
|
||||||
c2x2=c2x2,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
),
|
|
||||||
*upsample_blocks,
|
|
||||||
# hr_conv0
|
|
||||||
B.conv_block(
|
|
||||||
in_nc=self.num_filters,
|
|
||||||
out_nc=self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=self.act,
|
|
||||||
c2x2=c2x2,
|
|
||||||
),
|
|
||||||
# hr_conv1
|
|
||||||
B.conv_block(
|
|
||||||
in_nc=self.num_filters,
|
|
||||||
out_nc=self.out_nc,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
c2x2=c2x2,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Adjust these properties for calculations outside of the model
|
|
||||||
if self.shuffle_factor:
|
|
||||||
self.in_nc //= self.shuffle_factor**2
|
|
||||||
self.scale //= self.shuffle_factor
|
|
||||||
|
|
||||||
self.load_state_dict(self.state, strict=False)
|
|
||||||
|
|
||||||
def new_to_old_arch(self, state):
|
|
||||||
"""Convert a new-arch model state dictionary to an old-arch dictionary."""
|
|
||||||
if "params_ema" in state:
|
|
||||||
state = state["params_ema"]
|
|
||||||
|
|
||||||
if "conv_first.weight" not in state:
|
|
||||||
# model is already old arch, this is a loose check, but should be sufficient
|
|
||||||
return state
|
|
||||||
|
|
||||||
# add nb to state keys
|
|
||||||
for kind in ("weight", "bias"):
|
|
||||||
self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
|
|
||||||
f"model.1.sub./NB/.{kind}"
|
|
||||||
]
|
|
||||||
del self.state_map[f"model.1.sub./NB/.{kind}"]
|
|
||||||
|
|
||||||
old_state = OrderedDict()
|
|
||||||
for old_key, new_keys in self.state_map.items():
|
|
||||||
for new_key in new_keys:
|
|
||||||
if r"\1" in old_key:
|
|
||||||
for k, v in state.items():
|
|
||||||
sub = re.sub(new_key, old_key, k)
|
|
||||||
if sub != k:
|
|
||||||
old_state[sub] = v
|
|
||||||
else:
|
|
||||||
if new_key in state:
|
|
||||||
old_state[old_key] = state[new_key]
|
|
||||||
|
|
||||||
# upconv layers
|
|
||||||
max_upconv = 0
|
|
||||||
for key in state.keys():
|
|
||||||
match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
|
|
||||||
if match is not None:
|
|
||||||
_, key_num, key_type = match.groups()
|
|
||||||
old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
|
|
||||||
max_upconv = max(max_upconv, int(key_num) * 3)
|
|
||||||
|
|
||||||
# final layers
|
|
||||||
for key in state.keys():
|
|
||||||
if key in ("HRconv.weight", "conv_hr.weight"):
|
|
||||||
old_state[f"model.{max_upconv + 2}.weight"] = state[key]
|
|
||||||
elif key in ("HRconv.bias", "conv_hr.bias"):
|
|
||||||
old_state[f"model.{max_upconv + 2}.bias"] = state[key]
|
|
||||||
elif key in ("conv_last.weight",):
|
|
||||||
old_state[f"model.{max_upconv + 4}.weight"] = state[key]
|
|
||||||
elif key in ("conv_last.bias",):
|
|
||||||
old_state[f"model.{max_upconv + 4}.bias"] = state[key]
|
|
||||||
|
|
||||||
# Sort by first numeric value of each layer
|
|
||||||
def compare(item1, item2):
|
|
||||||
parts1 = item1.split(".")
|
|
||||||
parts2 = item2.split(".")
|
|
||||||
int1 = int(parts1[1])
|
|
||||||
int2 = int(parts2[1])
|
|
||||||
return int1 - int2
|
|
||||||
|
|
||||||
sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
|
|
||||||
|
|
||||||
# Rebuild the output dict in the right order
|
|
||||||
out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
|
|
||||||
|
|
||||||
return out_dict
|
|
||||||
|
|
||||||
def get_scale(self, min_part: int = 6) -> int:
|
|
||||||
n = 0
|
|
||||||
for part in list(self.state):
|
|
||||||
parts = part.split(".")[1:]
|
|
||||||
if len(parts) == 2:
|
|
||||||
part_num = int(parts[0])
|
|
||||||
if part_num > min_part and parts[1] == "weight":
|
|
||||||
n += 1
|
|
||||||
return 2**n
|
|
||||||
|
|
||||||
def get_num_blocks(self) -> int:
|
|
||||||
nbs = []
|
|
||||||
state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
|
|
||||||
r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
|
|
||||||
)
|
|
||||||
for state_key in state_keys:
|
|
||||||
for k in self.state:
|
|
||||||
m = re.search(state_key, k)
|
|
||||||
if m:
|
|
||||||
nbs.append(int(m.group(1)))
|
|
||||||
if nbs:
|
|
||||||
break
|
|
||||||
return max(*nbs) + 1
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.shuffle_factor:
|
|
||||||
_, _, h, w = x.size()
|
|
||||||
mod_pad_h = (
|
|
||||||
self.shuffle_factor - h % self.shuffle_factor
|
|
||||||
) % self.shuffle_factor
|
|
||||||
mod_pad_w = (
|
|
||||||
self.shuffle_factor - w % self.shuffle_factor
|
|
||||||
) % self.shuffle_factor
|
|
||||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
|
|
||||||
x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
|
|
||||||
x = self.model(x)
|
|
||||||
return x[:, :, : h * self.scale, : w * self.scale]
|
|
||||||
return self.model(x)
|
|
||||||
@ -1,455 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# -----------------------------------------------------------------------------------
|
|
||||||
# SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278
|
|
||||||
# Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, Radu and Van Gool, Luc
|
|
||||||
# -----------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange
|
|
||||||
from einops.layers.torch import Rearrange
|
|
||||||
|
|
||||||
from .timm.drop import DropPath
|
|
||||||
from .timm.weight_init import trunc_normal_
|
|
||||||
|
|
||||||
|
|
||||||
# Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py
|
|
||||||
class WMSA(nn.Module):
|
|
||||||
"""Self-attention module in Swin Transformer"""
|
|
||||||
|
|
||||||
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
|
||||||
super(WMSA, self).__init__()
|
|
||||||
self.input_dim = input_dim
|
|
||||||
self.output_dim = output_dim
|
|
||||||
self.head_dim = head_dim
|
|
||||||
self.scale = self.head_dim**-0.5
|
|
||||||
self.n_heads = input_dim // head_dim
|
|
||||||
self.window_size = window_size
|
|
||||||
self.type = type
|
|
||||||
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
|
|
||||||
|
|
||||||
self.relative_position_params = nn.Parameter(
|
|
||||||
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
|
|
||||||
)
|
|
||||||
# TODO recover
|
|
||||||
# self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1))
|
|
||||||
self.relative_position_params = nn.Parameter(
|
|
||||||
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
|
||||||
|
|
||||||
trunc_normal_(self.relative_position_params, std=0.02)
|
|
||||||
self.relative_position_params = torch.nn.Parameter(
|
|
||||||
self.relative_position_params.view(
|
|
||||||
2 * window_size - 1, 2 * window_size - 1, self.n_heads
|
|
||||||
)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.transpose(0, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_mask(self, h, w, p, shift):
|
|
||||||
"""generating the mask of SW-MSA
|
|
||||||
Args:
|
|
||||||
shift: shift parameters in CyclicShift.
|
|
||||||
Returns:
|
|
||||||
attn_mask: should be (1 1 w p p),
|
|
||||||
"""
|
|
||||||
# supporting square.
|
|
||||||
attn_mask = torch.zeros(
|
|
||||||
h,
|
|
||||||
w,
|
|
||||||
p,
|
|
||||||
p,
|
|
||||||
p,
|
|
||||||
p,
|
|
||||||
dtype=torch.bool,
|
|
||||||
device=self.relative_position_params.device,
|
|
||||||
)
|
|
||||||
if self.type == "W":
|
|
||||||
return attn_mask
|
|
||||||
|
|
||||||
s = p - shift
|
|
||||||
attn_mask[-1, :, :s, :, s:, :] = True
|
|
||||||
attn_mask[-1, :, s:, :, :s, :] = True
|
|
||||||
attn_mask[:, -1, :, :s, :, s:] = True
|
|
||||||
attn_mask[:, -1, :, s:, :, :s] = True
|
|
||||||
attn_mask = rearrange(
|
|
||||||
attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
|
|
||||||
)
|
|
||||||
return attn_mask
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""Forward pass of Window Multi-head Self-attention module.
|
|
||||||
Args:
|
|
||||||
x: input tensor with shape of [b h w c];
|
|
||||||
attn_mask: attention mask, fill -inf where the value is True;
|
|
||||||
Returns:
|
|
||||||
output: tensor shape [b h w c]
|
|
||||||
"""
|
|
||||||
if self.type != "W":
|
|
||||||
x = torch.roll(
|
|
||||||
x,
|
|
||||||
shifts=(-(self.window_size // 2), -(self.window_size // 2)),
|
|
||||||
dims=(1, 2),
|
|
||||||
)
|
|
||||||
|
|
||||||
x = rearrange(
|
|
||||||
x,
|
|
||||||
"b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c",
|
|
||||||
p1=self.window_size,
|
|
||||||
p2=self.window_size,
|
|
||||||
)
|
|
||||||
h_windows = x.size(1)
|
|
||||||
w_windows = x.size(2)
|
|
||||||
# square validation
|
|
||||||
# assert h_windows == w_windows
|
|
||||||
|
|
||||||
x = rearrange(
|
|
||||||
x,
|
|
||||||
"b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c",
|
|
||||||
p1=self.window_size,
|
|
||||||
p2=self.window_size,
|
|
||||||
)
|
|
||||||
qkv = self.embedding_layer(x)
|
|
||||||
q, k, v = rearrange(
|
|
||||||
qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
|
|
||||||
).chunk(3, dim=0)
|
|
||||||
sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
|
|
||||||
# Adding learnable relative embedding
|
|
||||||
sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
|
|
||||||
# Using Attn Mask to distinguish different subwindows.
|
|
||||||
if self.type != "W":
|
|
||||||
attn_mask = self.generate_mask(
|
|
||||||
h_windows, w_windows, self.window_size, shift=self.window_size // 2
|
|
||||||
)
|
|
||||||
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
|
||||||
|
|
||||||
probs = nn.functional.softmax(sim, dim=-1)
|
|
||||||
output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
|
|
||||||
output = rearrange(output, "h b w p c -> b w p (h c)")
|
|
||||||
output = self.linear(output)
|
|
||||||
output = rearrange(
|
|
||||||
output,
|
|
||||||
"b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
|
|
||||||
w1=h_windows,
|
|
||||||
p1=self.window_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.type != "W":
|
|
||||||
output = torch.roll(
|
|
||||||
output,
|
|
||||||
shifts=(self.window_size // 2, self.window_size // 2),
|
|
||||||
dims=(1, 2),
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def relative_embedding(self):
|
|
||||||
cord = torch.tensor(
|
|
||||||
np.array(
|
|
||||||
[
|
|
||||||
[i, j]
|
|
||||||
for i in range(self.window_size)
|
|
||||||
for j in range(self.window_size)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
|
|
||||||
# negative is allowed
|
|
||||||
return self.relative_position_params[
|
|
||||||
:, relation[:, :, 0].long(), relation[:, :, 1].long()
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_dim,
|
|
||||||
output_dim,
|
|
||||||
head_dim,
|
|
||||||
window_size,
|
|
||||||
drop_path,
|
|
||||||
type="W",
|
|
||||||
input_resolution=None,
|
|
||||||
):
|
|
||||||
"""SwinTransformer Block"""
|
|
||||||
super(Block, self).__init__()
|
|
||||||
self.input_dim = input_dim
|
|
||||||
self.output_dim = output_dim
|
|
||||||
assert type in ["W", "SW"]
|
|
||||||
self.type = type
|
|
||||||
if input_resolution <= window_size:
|
|
||||||
self.type = "W"
|
|
||||||
|
|
||||||
self.ln1 = nn.LayerNorm(input_dim)
|
|
||||||
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
|
||||||
self.ln2 = nn.LayerNorm(input_dim)
|
|
||||||
self.mlp = nn.Sequential(
|
|
||||||
nn.Linear(input_dim, 4 * input_dim),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Linear(4 * input_dim, output_dim),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x + self.drop_path(self.msa(self.ln1(x)))
|
|
||||||
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ConvTransBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
conv_dim,
|
|
||||||
trans_dim,
|
|
||||||
head_dim,
|
|
||||||
window_size,
|
|
||||||
drop_path,
|
|
||||||
type="W",
|
|
||||||
input_resolution=None,
|
|
||||||
):
|
|
||||||
"""SwinTransformer and Conv Block"""
|
|
||||||
super(ConvTransBlock, self).__init__()
|
|
||||||
self.conv_dim = conv_dim
|
|
||||||
self.trans_dim = trans_dim
|
|
||||||
self.head_dim = head_dim
|
|
||||||
self.window_size = window_size
|
|
||||||
self.drop_path = drop_path
|
|
||||||
self.type = type
|
|
||||||
self.input_resolution = input_resolution
|
|
||||||
|
|
||||||
assert self.type in ["W", "SW"]
|
|
||||||
if self.input_resolution <= self.window_size:
|
|
||||||
self.type = "W"
|
|
||||||
|
|
||||||
self.trans_block = Block(
|
|
||||||
self.trans_dim,
|
|
||||||
self.trans_dim,
|
|
||||||
self.head_dim,
|
|
||||||
self.window_size,
|
|
||||||
self.drop_path,
|
|
||||||
self.type,
|
|
||||||
self.input_resolution,
|
|
||||||
)
|
|
||||||
self.conv1_1 = nn.Conv2d(
|
|
||||||
self.conv_dim + self.trans_dim,
|
|
||||||
self.conv_dim + self.trans_dim,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
bias=True,
|
|
||||||
)
|
|
||||||
self.conv1_2 = nn.Conv2d(
|
|
||||||
self.conv_dim + self.trans_dim,
|
|
||||||
self.conv_dim + self.trans_dim,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
0,
|
|
||||||
bias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.conv_block = nn.Sequential(
|
|
||||||
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
|
||||||
nn.ReLU(True),
|
|
||||||
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
conv_x, trans_x = torch.split(
|
|
||||||
self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
|
|
||||||
)
|
|
||||||
conv_x = self.conv_block(conv_x) + conv_x
|
|
||||||
trans_x = Rearrange("b c h w -> b h w c")(trans_x)
|
|
||||||
trans_x = self.trans_block(trans_x)
|
|
||||||
trans_x = Rearrange("b h w c -> b c h w")(trans_x)
|
|
||||||
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
|
||||||
x = x + res
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SCUNet(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dict,
|
|
||||||
in_nc=3,
|
|
||||||
config=[4, 4, 4, 4, 4, 4, 4],
|
|
||||||
dim=64,
|
|
||||||
drop_path_rate=0.0,
|
|
||||||
input_resolution=256,
|
|
||||||
):
|
|
||||||
super(SCUNet, self).__init__()
|
|
||||||
self.model_arch = "SCUNet"
|
|
||||||
self.sub_type = "SR"
|
|
||||||
|
|
||||||
self.num_filters: int = 0
|
|
||||||
|
|
||||||
self.state = state_dict
|
|
||||||
self.config = config
|
|
||||||
self.dim = dim
|
|
||||||
self.head_dim = 32
|
|
||||||
self.window_size = 8
|
|
||||||
|
|
||||||
self.in_nc = in_nc
|
|
||||||
self.out_nc = self.in_nc
|
|
||||||
self.scale = 1
|
|
||||||
self.supports_fp16 = True
|
|
||||||
|
|
||||||
# drop path rate for each layer
|
|
||||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
|
|
||||||
|
|
||||||
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
|
|
||||||
|
|
||||||
begin = 0
|
|
||||||
self.m_down1 = [
|
|
||||||
ConvTransBlock(
|
|
||||||
dim // 2,
|
|
||||||
dim // 2,
|
|
||||||
self.head_dim,
|
|
||||||
self.window_size,
|
|
||||||
dpr[i + begin],
|
|
||||||
"W" if not i % 2 else "SW",
|
|
||||||
input_resolution,
|
|
||||||
)
|
|
||||||
for i in range(config[0])
|
|
||||||
] + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
|
|
||||||
|
|
||||||
begin += config[0]
|
|
||||||
self.m_down2 = [
|
|
||||||
ConvTransBlock(
|
|
||||||
dim,
|
|
||||||
dim,
|
|
||||||
self.head_dim,
|
|
||||||
self.window_size,
|
|
||||||
dpr[i + begin],
|
|
||||||
"W" if not i % 2 else "SW",
|
|
||||||
input_resolution // 2,
|
|
||||||
)
|
|
||||||
for i in range(config[1])
|
|
||||||
] + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
|
|
||||||
|
|
||||||
begin += config[1]
|
|
||||||
self.m_down3 = [
|
|
||||||
ConvTransBlock(
|
|
||||||
2 * dim,
|
|
||||||
2 * dim,
|
|
||||||
self.head_dim,
|
|
||||||
self.window_size,
|
|
||||||
dpr[i + begin],
|
|
||||||
"W" if not i % 2 else "SW",
|
|
||||||
input_resolution // 4,
|
|
||||||
)
|
|
||||||
for i in range(config[2])
|
|
||||||
] + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
|
|
||||||
|
|
||||||
begin += config[2]
|
|
||||||
self.m_body = [
|
|
||||||
ConvTransBlock(
|
|
||||||
4 * dim,
|
|
||||||
4 * dim,
|
|
||||||
self.head_dim,
|
|
||||||
self.window_size,
|
|
||||||
dpr[i + begin],
|
|
||||||
"W" if not i % 2 else "SW",
|
|
||||||
input_resolution // 8,
|
|
||||||
)
|
|
||||||
for i in range(config[3])
|
|
||||||
]
|
|
||||||
|
|
||||||
begin += config[3]
|
|
||||||
self.m_up3 = [
|
|
||||||
nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False),
|
|
||||||
] + [
|
|
||||||
ConvTransBlock(
|
|
||||||
2 * dim,
|
|
||||||
2 * dim,
|
|
||||||
self.head_dim,
|
|
||||||
self.window_size,
|
|
||||||
dpr[i + begin],
|
|
||||||
"W" if not i % 2 else "SW",
|
|
||||||
input_resolution // 4,
|
|
||||||
)
|
|
||||||
for i in range(config[4])
|
|
||||||
]
|
|
||||||
|
|
||||||
begin += config[4]
|
|
||||||
self.m_up2 = [
|
|
||||||
nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False),
|
|
||||||
] + [
|
|
||||||
ConvTransBlock(
|
|
||||||
dim,
|
|
||||||
dim,
|
|
||||||
self.head_dim,
|
|
||||||
self.window_size,
|
|
||||||
dpr[i + begin],
|
|
||||||
"W" if not i % 2 else "SW",
|
|
||||||
input_resolution // 2,
|
|
||||||
)
|
|
||||||
for i in range(config[5])
|
|
||||||
]
|
|
||||||
|
|
||||||
begin += config[5]
|
|
||||||
self.m_up1 = [
|
|
||||||
nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False),
|
|
||||||
] + [
|
|
||||||
ConvTransBlock(
|
|
||||||
dim // 2,
|
|
||||||
dim // 2,
|
|
||||||
self.head_dim,
|
|
||||||
self.window_size,
|
|
||||||
dpr[i + begin],
|
|
||||||
"W" if not i % 2 else "SW",
|
|
||||||
input_resolution,
|
|
||||||
)
|
|
||||||
for i in range(config[6])
|
|
||||||
]
|
|
||||||
|
|
||||||
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
|
|
||||||
|
|
||||||
self.m_head = nn.Sequential(*self.m_head)
|
|
||||||
self.m_down1 = nn.Sequential(*self.m_down1)
|
|
||||||
self.m_down2 = nn.Sequential(*self.m_down2)
|
|
||||||
self.m_down3 = nn.Sequential(*self.m_down3)
|
|
||||||
self.m_body = nn.Sequential(*self.m_body)
|
|
||||||
self.m_up3 = nn.Sequential(*self.m_up3)
|
|
||||||
self.m_up2 = nn.Sequential(*self.m_up2)
|
|
||||||
self.m_up1 = nn.Sequential(*self.m_up1)
|
|
||||||
self.m_tail = nn.Sequential(*self.m_tail)
|
|
||||||
# self.apply(self._init_weights)
|
|
||||||
self.load_state_dict(state_dict, strict=True)
|
|
||||||
|
|
||||||
def check_image_size(self, x):
|
|
||||||
_, _, h, w = x.size()
|
|
||||||
mod_pad_h = (64 - h % 64) % 64
|
|
||||||
mod_pad_w = (64 - w % 64) % 64
|
|
||||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, x0):
|
|
||||||
h, w = x0.size()[-2:]
|
|
||||||
x0 = self.check_image_size(x0)
|
|
||||||
|
|
||||||
x1 = self.m_head(x0)
|
|
||||||
x2 = self.m_down1(x1)
|
|
||||||
x3 = self.m_down2(x2)
|
|
||||||
x4 = self.m_down3(x3)
|
|
||||||
x = self.m_body(x4)
|
|
||||||
x = self.m_up3(x + x4)
|
|
||||||
x = self.m_up2(x + x3)
|
|
||||||
x = self.m_up1(x + x2)
|
|
||||||
x = self.m_tail(x + x1)
|
|
||||||
|
|
||||||
x = x[:, :, :h, :w]
|
|
||||||
return x
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=0.02)
|
|
||||||
if m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
@ -1,383 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from . import block as B
|
|
||||||
|
|
||||||
|
|
||||||
class Get_gradient_nopadding(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(Get_gradient_nopadding, self).__init__()
|
|
||||||
kernel_v = [[0, -1, 0], [0, 0, 0], [0, 1, 0]]
|
|
||||||
kernel_h = [[0, 0, 0], [-1, 0, 1], [0, 0, 0]]
|
|
||||||
kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
|
|
||||||
kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
|
|
||||||
self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False) # type: ignore
|
|
||||||
|
|
||||||
self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False) # type: ignore
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x_list = []
|
|
||||||
for i in range(x.shape[1]):
|
|
||||||
x_i = x[:, i]
|
|
||||||
x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
|
|
||||||
x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
|
|
||||||
x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
|
|
||||||
x_list.append(x_i)
|
|
||||||
|
|
||||||
x = torch.cat(x_list, dim=1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SPSRNet(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dict,
|
|
||||||
norm=None,
|
|
||||||
act: str = "leakyrelu",
|
|
||||||
upsampler: str = "upconv",
|
|
||||||
mode: B.ConvMode = "CNA",
|
|
||||||
):
|
|
||||||
super(SPSRNet, self).__init__()
|
|
||||||
self.model_arch = "SPSR"
|
|
||||||
self.sub_type = "SR"
|
|
||||||
|
|
||||||
self.state = state_dict
|
|
||||||
self.norm = norm
|
|
||||||
self.act = act
|
|
||||||
self.upsampler = upsampler
|
|
||||||
self.mode = mode
|
|
||||||
|
|
||||||
self.num_blocks = self.get_num_blocks()
|
|
||||||
|
|
||||||
self.in_nc: int = self.state["model.0.weight"].shape[1]
|
|
||||||
self.out_nc: int = self.state["f_HR_conv1.0.bias"].shape[0]
|
|
||||||
|
|
||||||
self.scale = self.get_scale(4)
|
|
||||||
self.num_filters: int = self.state["model.0.weight"].shape[0]
|
|
||||||
|
|
||||||
self.supports_fp16 = True
|
|
||||||
self.supports_bfp16 = True
|
|
||||||
self.min_size_restriction = None
|
|
||||||
|
|
||||||
n_upscale = int(math.log(self.scale, 2))
|
|
||||||
if self.scale == 3:
|
|
||||||
n_upscale = 1
|
|
||||||
|
|
||||||
fea_conv = B.conv_block(
|
|
||||||
self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
|
|
||||||
)
|
|
||||||
rb_blocks = [
|
|
||||||
B.RRDB(
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
gc=32,
|
|
||||||
stride=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=norm,
|
|
||||||
act_type=act,
|
|
||||||
mode="CNA",
|
|
||||||
)
|
|
||||||
for _ in range(self.num_blocks)
|
|
||||||
]
|
|
||||||
LR_conv = B.conv_block(
|
|
||||||
self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=norm,
|
|
||||||
act_type=None,
|
|
||||||
mode=mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
if upsampler == "upconv":
|
|
||||||
upsample_block = B.upconv_block
|
|
||||||
elif upsampler == "pixelshuffle":
|
|
||||||
upsample_block = B.pixelshuffle_block
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
|
|
||||||
if self.scale == 3:
|
|
||||||
a_upsampler = upsample_block(
|
|
||||||
self.num_filters, self.num_filters, 3, act_type=act
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
a_upsampler = [
|
|
||||||
upsample_block(self.num_filters, self.num_filters, act_type=act)
|
|
||||||
for _ in range(n_upscale)
|
|
||||||
]
|
|
||||||
self.HR_conv0_new = B.conv_block(
|
|
||||||
self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=act,
|
|
||||||
)
|
|
||||||
self.HR_conv1_new = B.conv_block(
|
|
||||||
self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model = B.sequential(
|
|
||||||
fea_conv,
|
|
||||||
B.ShortcutBlockSPSR(B.sequential(*rb_blocks, LR_conv)),
|
|
||||||
*a_upsampler,
|
|
||||||
self.HR_conv0_new,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.get_g_nopadding = Get_gradient_nopadding()
|
|
||||||
|
|
||||||
self.b_fea_conv = B.conv_block(
|
|
||||||
self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.b_concat_1 = B.conv_block(
|
|
||||||
2 * self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
)
|
|
||||||
self.b_block_1 = B.RRDB(
|
|
||||||
self.num_filters * 2,
|
|
||||||
kernel_size=3,
|
|
||||||
gc=32,
|
|
||||||
stride=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=norm,
|
|
||||||
act_type=act,
|
|
||||||
mode="CNA",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.b_concat_2 = B.conv_block(
|
|
||||||
2 * self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
)
|
|
||||||
self.b_block_2 = B.RRDB(
|
|
||||||
self.num_filters * 2,
|
|
||||||
kernel_size=3,
|
|
||||||
gc=32,
|
|
||||||
stride=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=norm,
|
|
||||||
act_type=act,
|
|
||||||
mode="CNA",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.b_concat_3 = B.conv_block(
|
|
||||||
2 * self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
)
|
|
||||||
self.b_block_3 = B.RRDB(
|
|
||||||
self.num_filters * 2,
|
|
||||||
kernel_size=3,
|
|
||||||
gc=32,
|
|
||||||
stride=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=norm,
|
|
||||||
act_type=act,
|
|
||||||
mode="CNA",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.b_concat_4 = B.conv_block(
|
|
||||||
2 * self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
)
|
|
||||||
self.b_block_4 = B.RRDB(
|
|
||||||
self.num_filters * 2,
|
|
||||||
kernel_size=3,
|
|
||||||
gc=32,
|
|
||||||
stride=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=norm,
|
|
||||||
act_type=act,
|
|
||||||
mode="CNA",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.b_LR_conv = B.conv_block(
|
|
||||||
self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=norm,
|
|
||||||
act_type=None,
|
|
||||||
mode=mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
if upsampler == "upconv":
|
|
||||||
upsample_block = B.upconv_block
|
|
||||||
elif upsampler == "pixelshuffle":
|
|
||||||
upsample_block = B.pixelshuffle_block
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
|
|
||||||
if self.scale == 3:
|
|
||||||
b_upsampler = upsample_block(
|
|
||||||
self.num_filters, self.num_filters, 3, act_type=act
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
b_upsampler = [
|
|
||||||
upsample_block(self.num_filters, self.num_filters, act_type=act)
|
|
||||||
for _ in range(n_upscale)
|
|
||||||
]
|
|
||||||
|
|
||||||
b_HR_conv0 = B.conv_block(
|
|
||||||
self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=act,
|
|
||||||
)
|
|
||||||
b_HR_conv1 = B.conv_block(
|
|
||||||
self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
|
|
||||||
|
|
||||||
self.conv_w = B.conv_block(
|
|
||||||
self.num_filters, self.out_nc, kernel_size=1, norm_type=None, act_type=None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.f_concat = B.conv_block(
|
|
||||||
self.num_filters * 2,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.f_block = B.RRDB(
|
|
||||||
self.num_filters * 2,
|
|
||||||
kernel_size=3,
|
|
||||||
gc=32,
|
|
||||||
stride=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=norm,
|
|
||||||
act_type=act,
|
|
||||||
mode="CNA",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.f_HR_conv0 = B.conv_block(
|
|
||||||
self.num_filters,
|
|
||||||
self.num_filters,
|
|
||||||
kernel_size=3,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=act,
|
|
||||||
)
|
|
||||||
self.f_HR_conv1 = B.conv_block(
|
|
||||||
self.num_filters, self.out_nc, kernel_size=3, norm_type=None, act_type=None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.load_state_dict(self.state, strict=False)
|
|
||||||
|
|
||||||
def get_scale(self, min_part: int = 4) -> int:
|
|
||||||
n = 0
|
|
||||||
for part in list(self.state):
|
|
||||||
parts = part.split(".")
|
|
||||||
if len(parts) == 3:
|
|
||||||
part_num = int(parts[1])
|
|
||||||
if part_num > min_part and parts[0] == "model" and parts[2] == "weight":
|
|
||||||
n += 1
|
|
||||||
return 2**n
|
|
||||||
|
|
||||||
def get_num_blocks(self) -> int:
|
|
||||||
nb = 0
|
|
||||||
for part in list(self.state):
|
|
||||||
parts = part.split(".")
|
|
||||||
n_parts = len(parts)
|
|
||||||
if n_parts == 5 and parts[2] == "sub":
|
|
||||||
nb = int(parts[3])
|
|
||||||
return nb
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x_grad = self.get_g_nopadding(x)
|
|
||||||
x = self.model[0](x)
|
|
||||||
|
|
||||||
x, block_list = self.model[1](x)
|
|
||||||
|
|
||||||
x_ori = x
|
|
||||||
for i in range(5):
|
|
||||||
x = block_list[i](x)
|
|
||||||
x_fea1 = x
|
|
||||||
|
|
||||||
for i in range(5):
|
|
||||||
x = block_list[i + 5](x)
|
|
||||||
x_fea2 = x
|
|
||||||
|
|
||||||
for i in range(5):
|
|
||||||
x = block_list[i + 10](x)
|
|
||||||
x_fea3 = x
|
|
||||||
|
|
||||||
for i in range(5):
|
|
||||||
x = block_list[i + 15](x)
|
|
||||||
x_fea4 = x
|
|
||||||
|
|
||||||
x = block_list[20:](x)
|
|
||||||
# short cut
|
|
||||||
x = x_ori + x
|
|
||||||
x = self.model[2:](x)
|
|
||||||
x = self.HR_conv1_new(x)
|
|
||||||
|
|
||||||
x_b_fea = self.b_fea_conv(x_grad)
|
|
||||||
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
|
||||||
|
|
||||||
x_cat_1 = self.b_block_1(x_cat_1)
|
|
||||||
x_cat_1 = self.b_concat_1(x_cat_1)
|
|
||||||
|
|
||||||
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
|
|
||||||
|
|
||||||
x_cat_2 = self.b_block_2(x_cat_2)
|
|
||||||
x_cat_2 = self.b_concat_2(x_cat_2)
|
|
||||||
|
|
||||||
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
|
|
||||||
|
|
||||||
x_cat_3 = self.b_block_3(x_cat_3)
|
|
||||||
x_cat_3 = self.b_concat_3(x_cat_3)
|
|
||||||
|
|
||||||
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
|
|
||||||
|
|
||||||
x_cat_4 = self.b_block_4(x_cat_4)
|
|
||||||
x_cat_4 = self.b_concat_4(x_cat_4)
|
|
||||||
|
|
||||||
x_cat_4 = self.b_LR_conv(x_cat_4)
|
|
||||||
|
|
||||||
# short cut
|
|
||||||
x_cat_4 = x_cat_4 + x_b_fea
|
|
||||||
x_branch = self.b_module(x_cat_4)
|
|
||||||
|
|
||||||
# x_out_branch = self.conv_w(x_branch)
|
|
||||||
########
|
|
||||||
x_branch_d = x_branch
|
|
||||||
x_f_cat = torch.cat([x_branch_d, x], dim=1)
|
|
||||||
x_f_cat = self.f_block(x_f_cat)
|
|
||||||
x_out = self.f_concat(x_f_cat)
|
|
||||||
x_out = self.f_HR_conv0(x_out)
|
|
||||||
x_out = self.f_HR_conv1(x_out)
|
|
||||||
|
|
||||||
#########
|
|
||||||
# return x_out_branch, x_out, x_grad
|
|
||||||
return x_out
|
|
||||||
@ -1,114 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class SRVGGNetCompact(nn.Module):
|
|
||||||
"""A compact VGG-style network structure for super-resolution.
|
|
||||||
It is a compact network structure, which performs upsampling in the last layer and no convolution is
|
|
||||||
conducted on the HR feature space.
|
|
||||||
Args:
|
|
||||||
num_in_ch (int): Channel number of inputs. Default: 3.
|
|
||||||
num_out_ch (int): Channel number of outputs. Default: 3.
|
|
||||||
num_feat (int): Channel number of intermediate features. Default: 64.
|
|
||||||
num_conv (int): Number of convolution layers in the body network. Default: 16.
|
|
||||||
upscale (int): Upsampling factor. Default: 4.
|
|
||||||
act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dict,
|
|
||||||
act_type: str = "prelu",
|
|
||||||
):
|
|
||||||
super(SRVGGNetCompact, self).__init__()
|
|
||||||
self.model_arch = "SRVGG (RealESRGAN)"
|
|
||||||
self.sub_type = "SR"
|
|
||||||
|
|
||||||
self.act_type = act_type
|
|
||||||
|
|
||||||
self.state = state_dict
|
|
||||||
|
|
||||||
if "params" in self.state:
|
|
||||||
self.state = self.state["params"]
|
|
||||||
|
|
||||||
self.key_arr = list(self.state.keys())
|
|
||||||
|
|
||||||
self.in_nc = self.get_in_nc()
|
|
||||||
self.num_feat = self.get_num_feats()
|
|
||||||
self.num_conv = self.get_num_conv()
|
|
||||||
self.out_nc = self.in_nc # :(
|
|
||||||
self.pixelshuffle_shape = None # Defined in get_scale()
|
|
||||||
self.scale = self.get_scale()
|
|
||||||
|
|
||||||
self.supports_fp16 = True
|
|
||||||
self.supports_bfp16 = True
|
|
||||||
self.min_size_restriction = None
|
|
||||||
|
|
||||||
self.body = nn.ModuleList()
|
|
||||||
# the first conv
|
|
||||||
self.body.append(nn.Conv2d(self.in_nc, self.num_feat, 3, 1, 1))
|
|
||||||
# the first activation
|
|
||||||
if act_type == "relu":
|
|
||||||
activation = nn.ReLU(inplace=True)
|
|
||||||
elif act_type == "prelu":
|
|
||||||
activation = nn.PReLU(num_parameters=self.num_feat)
|
|
||||||
elif act_type == "leakyrelu":
|
|
||||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
|
||||||
self.body.append(activation) # type: ignore
|
|
||||||
|
|
||||||
# the body structure
|
|
||||||
for _ in range(self.num_conv):
|
|
||||||
self.body.append(nn.Conv2d(self.num_feat, self.num_feat, 3, 1, 1))
|
|
||||||
# activation
|
|
||||||
if act_type == "relu":
|
|
||||||
activation = nn.ReLU(inplace=True)
|
|
||||||
elif act_type == "prelu":
|
|
||||||
activation = nn.PReLU(num_parameters=self.num_feat)
|
|
||||||
elif act_type == "leakyrelu":
|
|
||||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
|
||||||
self.body.append(activation) # type: ignore
|
|
||||||
|
|
||||||
# the last conv
|
|
||||||
self.body.append(nn.Conv2d(self.num_feat, self.pixelshuffle_shape, 3, 1, 1)) # type: ignore
|
|
||||||
# upsample
|
|
||||||
self.upsampler = nn.PixelShuffle(self.scale)
|
|
||||||
|
|
||||||
self.load_state_dict(self.state, strict=False)
|
|
||||||
|
|
||||||
def get_num_conv(self) -> int:
|
|
||||||
return (int(self.key_arr[-1].split(".")[1]) - 2) // 2
|
|
||||||
|
|
||||||
def get_num_feats(self) -> int:
|
|
||||||
return self.state[self.key_arr[0]].shape[0]
|
|
||||||
|
|
||||||
def get_in_nc(self) -> int:
|
|
||||||
return self.state[self.key_arr[0]].shape[1]
|
|
||||||
|
|
||||||
def get_scale(self) -> int:
|
|
||||||
self.pixelshuffle_shape = self.state[self.key_arr[-1]].shape[0]
|
|
||||||
# Assume out_nc is the same as in_nc
|
|
||||||
# I cant think of a better way to do that
|
|
||||||
self.out_nc = self.in_nc
|
|
||||||
scale = math.sqrt(self.pixelshuffle_shape / self.out_nc)
|
|
||||||
if scale - int(scale) > 0:
|
|
||||||
print(
|
|
||||||
"out_nc is probably different than in_nc, scale calculation might be wrong"
|
|
||||||
)
|
|
||||||
scale = int(scale)
|
|
||||||
return scale
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = x
|
|
||||||
for i in range(0, len(self.body)):
|
|
||||||
out = self.body[i](out)
|
|
||||||
|
|
||||||
out = self.upsampler(out)
|
|
||||||
# add the nearest upsampled image, so that the network learns the residual
|
|
||||||
base = F.interpolate(x, scale_factor=self.scale, mode="nearest")
|
|
||||||
out += base
|
|
||||||
return out
|
|
||||||
@ -1,161 +0,0 @@
|
|||||||
# From https://github.com/Koushik0901/Swift-SRGAN/blob/master/swift-srgan/models.py
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
class SeperableConv2d(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True
|
|
||||||
):
|
|
||||||
super(SeperableConv2d, self).__init__()
|
|
||||||
self.depthwise = nn.Conv2d(
|
|
||||||
in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
groups=in_channels,
|
|
||||||
bias=bias,
|
|
||||||
padding=padding,
|
|
||||||
)
|
|
||||||
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.pointwise(self.depthwise(x))
|
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
use_act=True,
|
|
||||||
use_bn=True,
|
|
||||||
discriminator=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super(ConvBlock, self).__init__()
|
|
||||||
|
|
||||||
self.use_act = use_act
|
|
||||||
self.cnn = SeperableConv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
|
|
||||||
self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
|
|
||||||
self.act = (
|
|
||||||
nn.LeakyReLU(0.2, inplace=True)
|
|
||||||
if discriminator
|
|
||||||
else nn.PReLU(num_parameters=out_channels)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))
|
|
||||||
|
|
||||||
|
|
||||||
class UpsampleBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, scale_factor):
|
|
||||||
super(UpsampleBlock, self).__init__()
|
|
||||||
|
|
||||||
self.conv = SeperableConv2d(
|
|
||||||
in_channels,
|
|
||||||
in_channels * scale_factor**2,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
)
|
|
||||||
self.ps = nn.PixelShuffle(
|
|
||||||
scale_factor
|
|
||||||
) # (in_channels * 4, H, W) -> (in_channels, H*2, W*2)
|
|
||||||
self.act = nn.PReLU(num_parameters=in_channels)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.act(self.ps(self.conv(x)))
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super(ResidualBlock, self).__init__()
|
|
||||||
|
|
||||||
self.block1 = ConvBlock(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
self.block2 = ConvBlock(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, use_act=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.block1(x)
|
|
||||||
out = self.block2(out)
|
|
||||||
return out + x
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
|
||||||
"""Swift-SRGAN Generator
|
|
||||||
Args:
|
|
||||||
in_channels (int): number of input image channels.
|
|
||||||
num_channels (int): number of hidden channels.
|
|
||||||
num_blocks (int): number of residual blocks.
|
|
||||||
upscale_factor (int): factor to upscale the image [2x, 4x, 8x].
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: super resolution image
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dict,
|
|
||||||
):
|
|
||||||
super(Generator, self).__init__()
|
|
||||||
self.model_arch = "Swift-SRGAN"
|
|
||||||
self.sub_type = "SR"
|
|
||||||
self.state = state_dict
|
|
||||||
if "model" in self.state:
|
|
||||||
self.state = self.state["model"]
|
|
||||||
|
|
||||||
self.in_nc: int = self.state["initial.cnn.depthwise.weight"].shape[0]
|
|
||||||
self.out_nc: int = self.state["final_conv.pointwise.weight"].shape[0]
|
|
||||||
self.num_filters: int = self.state["initial.cnn.pointwise.weight"].shape[0]
|
|
||||||
self.num_blocks = len(
|
|
||||||
set([x.split(".")[1] for x in self.state.keys() if "residual" in x])
|
|
||||||
)
|
|
||||||
self.scale: int = 2 ** len(
|
|
||||||
set([x.split(".")[1] for x in self.state.keys() if "upsampler" in x])
|
|
||||||
)
|
|
||||||
|
|
||||||
in_channels = self.in_nc
|
|
||||||
num_channels = self.num_filters
|
|
||||||
num_blocks = self.num_blocks
|
|
||||||
upscale_factor = self.scale
|
|
||||||
|
|
||||||
self.supports_fp16 = True
|
|
||||||
self.supports_bfp16 = True
|
|
||||||
self.min_size_restriction = None
|
|
||||||
|
|
||||||
self.initial = ConvBlock(
|
|
||||||
in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
|
|
||||||
)
|
|
||||||
self.residual = nn.Sequential(
|
|
||||||
*[ResidualBlock(num_channels) for _ in range(num_blocks)]
|
|
||||||
)
|
|
||||||
self.convblock = ConvBlock(
|
|
||||||
num_channels,
|
|
||||||
num_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
use_act=False,
|
|
||||||
)
|
|
||||||
self.upsampler = nn.Sequential(
|
|
||||||
*[
|
|
||||||
UpsampleBlock(num_channels, scale_factor=2)
|
|
||||||
for _ in range(upscale_factor // 2)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.final_conv = SeperableConv2d(
|
|
||||||
num_channels, in_channels, kernel_size=9, stride=1, padding=4
|
|
||||||
)
|
|
||||||
|
|
||||||
self.load_state_dict(self.state, strict=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
initial = self.initial(x)
|
|
||||||
x = self.residual(initial)
|
|
||||||
x = self.convblock(x) + initial
|
|
||||||
x = self.upsampler(x)
|
|
||||||
return (torch.tanh(self.final_conv(x)) + 1) / 2
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,546 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
try:
|
|
||||||
from typing import Literal
|
|
||||||
except ImportError:
|
|
||||||
from typing_extensions import Literal
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
####################
|
|
||||||
# Basic blocks
|
|
||||||
####################
|
|
||||||
|
|
||||||
|
|
||||||
def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
|
|
||||||
# helper selecting activation
|
|
||||||
# neg_slope: for leakyrelu and init of prelu
|
|
||||||
# n_prelu: for p_relu num_parameters
|
|
||||||
act_type = act_type.lower()
|
|
||||||
if act_type == "relu":
|
|
||||||
layer = nn.ReLU(inplace)
|
|
||||||
elif act_type == "leakyrelu":
|
|
||||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
|
||||||
elif act_type == "prelu":
|
|
||||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"activation layer [{:s}] is not found".format(act_type)
|
|
||||||
)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
def norm(norm_type: str, nc: int):
|
|
||||||
# helper selecting normalization layer
|
|
||||||
norm_type = norm_type.lower()
|
|
||||||
if norm_type == "batch":
|
|
||||||
layer = nn.BatchNorm2d(nc, affine=True)
|
|
||||||
elif norm_type == "instance":
|
|
||||||
layer = nn.InstanceNorm2d(nc, affine=False)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"normalization layer [{:s}] is not found".format(norm_type)
|
|
||||||
)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
def pad(pad_type: str, padding):
|
|
||||||
# helper selecting padding layer
|
|
||||||
# if padding is 'zero', do by conv layers
|
|
||||||
pad_type = pad_type.lower()
|
|
||||||
if padding == 0:
|
|
||||||
return None
|
|
||||||
if pad_type == "reflect":
|
|
||||||
layer = nn.ReflectionPad2d(padding)
|
|
||||||
elif pad_type == "replicate":
|
|
||||||
layer = nn.ReplicationPad2d(padding)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"padding layer [{:s}] is not implemented".format(pad_type)
|
|
||||||
)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
def get_valid_padding(kernel_size, dilation):
|
|
||||||
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
|
||||||
padding = (kernel_size - 1) // 2
|
|
||||||
return padding
|
|
||||||
|
|
||||||
|
|
||||||
class ConcatBlock(nn.Module):
|
|
||||||
# Concat the output of a submodule to its input
|
|
||||||
def __init__(self, submodule):
|
|
||||||
super(ConcatBlock, self).__init__()
|
|
||||||
self.sub = submodule
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
output = torch.cat((x, self.sub(x)), dim=1)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
tmpstr = "Identity .. \n|"
|
|
||||||
modstr = self.sub.__repr__().replace("\n", "\n|")
|
|
||||||
tmpstr = tmpstr + modstr
|
|
||||||
return tmpstr
|
|
||||||
|
|
||||||
|
|
||||||
class ShortcutBlock(nn.Module):
|
|
||||||
# Elementwise sum the output of a submodule to its input
|
|
||||||
def __init__(self, submodule):
|
|
||||||
super(ShortcutBlock, self).__init__()
|
|
||||||
self.sub = submodule
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
output = x + self.sub(x)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
tmpstr = "Identity + \n|"
|
|
||||||
modstr = self.sub.__repr__().replace("\n", "\n|")
|
|
||||||
tmpstr = tmpstr + modstr
|
|
||||||
return tmpstr
|
|
||||||
|
|
||||||
|
|
||||||
class ShortcutBlockSPSR(nn.Module):
|
|
||||||
# Elementwise sum the output of a submodule to its input
|
|
||||||
def __init__(self, submodule):
|
|
||||||
super(ShortcutBlockSPSR, self).__init__()
|
|
||||||
self.sub = submodule
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x, self.sub
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
tmpstr = "Identity + \n|"
|
|
||||||
modstr = self.sub.__repr__().replace("\n", "\n|")
|
|
||||||
tmpstr = tmpstr + modstr
|
|
||||||
return tmpstr
|
|
||||||
|
|
||||||
|
|
||||||
def sequential(*args):
|
|
||||||
# Flatten Sequential. It unwraps nn.Sequential.
|
|
||||||
if len(args) == 1:
|
|
||||||
if isinstance(args[0], OrderedDict):
|
|
||||||
raise NotImplementedError("sequential does not support OrderedDict input.")
|
|
||||||
return args[0] # No sequential is needed.
|
|
||||||
modules = []
|
|
||||||
for module in args:
|
|
||||||
if isinstance(module, nn.Sequential):
|
|
||||||
for submodule in module.children():
|
|
||||||
modules.append(submodule)
|
|
||||||
elif isinstance(module, nn.Module):
|
|
||||||
modules.append(module)
|
|
||||||
return nn.Sequential(*modules)
|
|
||||||
|
|
||||||
|
|
||||||
ConvMode = Literal["CNA", "NAC", "CNAC"]
|
|
||||||
|
|
||||||
|
|
||||||
# 2x2x2 Conv Block
|
|
||||||
def conv_block_2c2(
|
|
||||||
in_nc,
|
|
||||||
out_nc,
|
|
||||||
act_type="relu",
|
|
||||||
):
|
|
||||||
return sequential(
|
|
||||||
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
|
|
||||||
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
|
|
||||||
act(act_type) if act_type else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def conv_block(
|
|
||||||
in_nc: int,
|
|
||||||
out_nc: int,
|
|
||||||
kernel_size,
|
|
||||||
stride=1,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type: str | None = None,
|
|
||||||
act_type: str | None = "relu",
|
|
||||||
mode: ConvMode = "CNA",
|
|
||||||
c2x2=False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Conv layer with padding, normalization, activation
|
|
||||||
mode: CNA --> Conv -> Norm -> Act
|
|
||||||
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
|
||||||
"""
|
|
||||||
|
|
||||||
if c2x2:
|
|
||||||
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
|
|
||||||
|
|
||||||
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
|
|
||||||
padding = get_valid_padding(kernel_size, dilation)
|
|
||||||
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
|
|
||||||
padding = padding if pad_type == "zero" else 0
|
|
||||||
|
|
||||||
c = nn.Conv2d(
|
|
||||||
in_nc,
|
|
||||||
out_nc,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
dilation=dilation,
|
|
||||||
bias=bias,
|
|
||||||
groups=groups,
|
|
||||||
)
|
|
||||||
a = act(act_type) if act_type else None
|
|
||||||
if mode in ("CNA", "CNAC"):
|
|
||||||
n = norm(norm_type, out_nc) if norm_type else None
|
|
||||||
return sequential(p, c, n, a)
|
|
||||||
elif mode == "NAC":
|
|
||||||
if norm_type is None and act_type is not None:
|
|
||||||
a = act(act_type, inplace=False)
|
|
||||||
# Important!
|
|
||||||
# input----ReLU(inplace)----Conv--+----output
|
|
||||||
# |________________________|
|
|
||||||
# inplace ReLU will modify the input, therefore wrong output
|
|
||||||
n = norm(norm_type, in_nc) if norm_type else None
|
|
||||||
return sequential(n, a, p, c)
|
|
||||||
else:
|
|
||||||
assert False, f"Invalid conv mode {mode}"
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
|
||||||
# Useful blocks
|
|
||||||
####################
|
|
||||||
|
|
||||||
|
|
||||||
class ResNetBlock(nn.Module):
|
|
||||||
"""
|
|
||||||
ResNet Block, 3-3 style
|
|
||||||
with extra residual scaling used in EDSR
|
|
||||||
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_nc,
|
|
||||||
mid_nc,
|
|
||||||
out_nc,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
dilation=1,
|
|
||||||
groups=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=None,
|
|
||||||
act_type="relu",
|
|
||||||
mode: ConvMode = "CNA",
|
|
||||||
res_scale=1,
|
|
||||||
):
|
|
||||||
super(ResNetBlock, self).__init__()
|
|
||||||
conv0 = conv_block(
|
|
||||||
in_nc,
|
|
||||||
mid_nc,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
dilation,
|
|
||||||
groups,
|
|
||||||
bias,
|
|
||||||
pad_type,
|
|
||||||
norm_type,
|
|
||||||
act_type,
|
|
||||||
mode,
|
|
||||||
)
|
|
||||||
if mode == "CNA":
|
|
||||||
act_type = None
|
|
||||||
if mode == "CNAC": # Residual path: |-CNAC-|
|
|
||||||
act_type = None
|
|
||||||
norm_type = None
|
|
||||||
conv1 = conv_block(
|
|
||||||
mid_nc,
|
|
||||||
out_nc,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
dilation,
|
|
||||||
groups,
|
|
||||||
bias,
|
|
||||||
pad_type,
|
|
||||||
norm_type,
|
|
||||||
act_type,
|
|
||||||
mode,
|
|
||||||
)
|
|
||||||
# if in_nc != out_nc:
|
|
||||||
# self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
|
|
||||||
# None, None)
|
|
||||||
# print('Need a projecter in ResNetBlock.')
|
|
||||||
# else:
|
|
||||||
# self.project = lambda x:x
|
|
||||||
self.res = sequential(conv0, conv1)
|
|
||||||
self.res_scale = res_scale
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
res = self.res(x).mul(self.res_scale)
|
|
||||||
return x + res
|
|
||||||
|
|
||||||
|
|
||||||
class RRDB(nn.Module):
|
|
||||||
"""
|
|
||||||
Residual in Residual Dense Block
|
|
||||||
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
nf,
|
|
||||||
kernel_size=3,
|
|
||||||
gc=32,
|
|
||||||
stride=1,
|
|
||||||
bias: bool = True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=None,
|
|
||||||
act_type="leakyrelu",
|
|
||||||
mode: ConvMode = "CNA",
|
|
||||||
_convtype="Conv2D",
|
|
||||||
_spectral_norm=False,
|
|
||||||
plus=False,
|
|
||||||
c2x2=False,
|
|
||||||
):
|
|
||||||
super(RRDB, self).__init__()
|
|
||||||
self.RDB1 = ResidualDenseBlock_5C(
|
|
||||||
nf,
|
|
||||||
kernel_size,
|
|
||||||
gc,
|
|
||||||
stride,
|
|
||||||
bias,
|
|
||||||
pad_type,
|
|
||||||
norm_type,
|
|
||||||
act_type,
|
|
||||||
mode,
|
|
||||||
plus=plus,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
self.RDB2 = ResidualDenseBlock_5C(
|
|
||||||
nf,
|
|
||||||
kernel_size,
|
|
||||||
gc,
|
|
||||||
stride,
|
|
||||||
bias,
|
|
||||||
pad_type,
|
|
||||||
norm_type,
|
|
||||||
act_type,
|
|
||||||
mode,
|
|
||||||
plus=plus,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
self.RDB3 = ResidualDenseBlock_5C(
|
|
||||||
nf,
|
|
||||||
kernel_size,
|
|
||||||
gc,
|
|
||||||
stride,
|
|
||||||
bias,
|
|
||||||
pad_type,
|
|
||||||
norm_type,
|
|
||||||
act_type,
|
|
||||||
mode,
|
|
||||||
plus=plus,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.RDB1(x)
|
|
||||||
out = self.RDB2(out)
|
|
||||||
out = self.RDB3(out)
|
|
||||||
return out * 0.2 + x
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualDenseBlock_5C(nn.Module):
|
|
||||||
"""
|
|
||||||
Residual Dense Block
|
|
||||||
style: 5 convs
|
|
||||||
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
|
||||||
Modified options that can be used:
|
|
||||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
|
||||||
- "Spectral normalization" arXiv:1802.05957
|
|
||||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
|
||||||
{Rakotonirina} and A. {Rasoanaivo}
|
|
||||||
|
|
||||||
Args:
|
|
||||||
nf (int): Channel number of intermediate features (num_feat).
|
|
||||||
gc (int): Channels for each growth (num_grow_ch: growth channel,
|
|
||||||
i.e. intermediate channels).
|
|
||||||
convtype (str): the type of convolution to use. Default: 'Conv2D'
|
|
||||||
gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
|
|
||||||
trainable parameters)
|
|
||||||
plus (bool): enable the additional residual paths from ESRGAN+
|
|
||||||
(adds trainable parameters)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
nf=64,
|
|
||||||
kernel_size=3,
|
|
||||||
gc=32,
|
|
||||||
stride=1,
|
|
||||||
bias: bool = True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type=None,
|
|
||||||
act_type="leakyrelu",
|
|
||||||
mode: ConvMode = "CNA",
|
|
||||||
plus=False,
|
|
||||||
c2x2=False,
|
|
||||||
):
|
|
||||||
super(ResidualDenseBlock_5C, self).__init__()
|
|
||||||
|
|
||||||
## +
|
|
||||||
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
|
||||||
## +
|
|
||||||
|
|
||||||
self.conv1 = conv_block(
|
|
||||||
nf,
|
|
||||||
gc,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
bias=bias,
|
|
||||||
pad_type=pad_type,
|
|
||||||
norm_type=norm_type,
|
|
||||||
act_type=act_type,
|
|
||||||
mode=mode,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
self.conv2 = conv_block(
|
|
||||||
nf + gc,
|
|
||||||
gc,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
bias=bias,
|
|
||||||
pad_type=pad_type,
|
|
||||||
norm_type=norm_type,
|
|
||||||
act_type=act_type,
|
|
||||||
mode=mode,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
self.conv3 = conv_block(
|
|
||||||
nf + 2 * gc,
|
|
||||||
gc,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
bias=bias,
|
|
||||||
pad_type=pad_type,
|
|
||||||
norm_type=norm_type,
|
|
||||||
act_type=act_type,
|
|
||||||
mode=mode,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
self.conv4 = conv_block(
|
|
||||||
nf + 3 * gc,
|
|
||||||
gc,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
bias=bias,
|
|
||||||
pad_type=pad_type,
|
|
||||||
norm_type=norm_type,
|
|
||||||
act_type=act_type,
|
|
||||||
mode=mode,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
if mode == "CNA":
|
|
||||||
last_act = None
|
|
||||||
else:
|
|
||||||
last_act = act_type
|
|
||||||
self.conv5 = conv_block(
|
|
||||||
nf + 4 * gc,
|
|
||||||
nf,
|
|
||||||
3,
|
|
||||||
stride,
|
|
||||||
bias=bias,
|
|
||||||
pad_type=pad_type,
|
|
||||||
norm_type=norm_type,
|
|
||||||
act_type=last_act,
|
|
||||||
mode=mode,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x1 = self.conv1(x)
|
|
||||||
x2 = self.conv2(torch.cat((x, x1), 1))
|
|
||||||
if self.conv1x1:
|
|
||||||
# pylint: disable=not-callable
|
|
||||||
x2 = x2 + self.conv1x1(x) # +
|
|
||||||
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
|
||||||
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
|
||||||
if self.conv1x1:
|
|
||||||
x4 = x4 + x2 # +
|
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
|
||||||
return x5 * 0.2 + x
|
|
||||||
|
|
||||||
|
|
||||||
def conv1x1(in_planes, out_planes, stride=1):
|
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
|
||||||
# Upsampler
|
|
||||||
####################
|
|
||||||
|
|
||||||
|
|
||||||
def pixelshuffle_block(
|
|
||||||
in_nc: int,
|
|
||||||
out_nc: int,
|
|
||||||
upscale_factor=2,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type: str | None = None,
|
|
||||||
act_type="relu",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Pixel shuffle layer
|
|
||||||
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
|
||||||
Neural Network, CVPR17)
|
|
||||||
"""
|
|
||||||
conv = conv_block(
|
|
||||||
in_nc,
|
|
||||||
out_nc * (upscale_factor**2),
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
bias=bias,
|
|
||||||
pad_type=pad_type,
|
|
||||||
norm_type=None,
|
|
||||||
act_type=None,
|
|
||||||
)
|
|
||||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
|
||||||
|
|
||||||
n = norm(norm_type, out_nc) if norm_type else None
|
|
||||||
a = act(act_type) if act_type else None
|
|
||||||
return sequential(conv, pixel_shuffle, n, a)
|
|
||||||
|
|
||||||
|
|
||||||
def upconv_block(
|
|
||||||
in_nc: int,
|
|
||||||
out_nc: int,
|
|
||||||
upscale_factor=2,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
bias=True,
|
|
||||||
pad_type="zero",
|
|
||||||
norm_type: str | None = None,
|
|
||||||
act_type="relu",
|
|
||||||
mode="nearest",
|
|
||||||
c2x2=False,
|
|
||||||
):
|
|
||||||
# Up conv
|
|
||||||
# described in https://distill.pub/2016/deconv-checkerboard/
|
|
||||||
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
|
|
||||||
conv = conv_block(
|
|
||||||
in_nc,
|
|
||||||
out_nc,
|
|
||||||
kernel_size,
|
|
||||||
stride,
|
|
||||||
bias=bias,
|
|
||||||
pad_type=pad_type,
|
|
||||||
norm_type=norm_type,
|
|
||||||
act_type=act_type,
|
|
||||||
c2x2=c2x2,
|
|
||||||
)
|
|
||||||
return sequential(upsample, conv)
|
|
||||||
@ -1,351 +0,0 @@
|
|||||||
Tencent is pleased to support the open source community by making GFPGAN available.
|
|
||||||
|
|
||||||
Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
|
|
||||||
|
|
||||||
GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
|
|
||||||
|
|
||||||
|
|
||||||
Terms of the Apache License Version 2.0:
|
|
||||||
---------------------------------------------
|
|
||||||
Apache License
|
|
||||||
|
|
||||||
Version 2.0, January 2004
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
|
|
||||||
|
|
||||||
“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
|
||||||
|
|
||||||
“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
|
||||||
|
|
||||||
“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
|
|
||||||
|
|
||||||
“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
|
||||||
|
|
||||||
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
|
||||||
|
|
||||||
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
|
||||||
|
|
||||||
If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Other dependencies and licenses:
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
|
|
||||||
---------------------------------------------
|
|
||||||
1. basicsr
|
|
||||||
Copyright 2018-2020 BasicSR Authors
|
|
||||||
|
|
||||||
|
|
||||||
This BasicSR project is released under the Apache 2.0 license.
|
|
||||||
|
|
||||||
A copy of Apache 2.0 is included in this file.
|
|
||||||
|
|
||||||
StyleGAN2
|
|
||||||
The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
|
|
||||||
The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
|
|
||||||
DFDNet
|
|
||||||
The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
|
|
||||||
|
|
||||||
Terms of the Nvidia License:
|
|
||||||
---------------------------------------------
|
|
||||||
|
|
||||||
1. Definitions
|
|
||||||
|
|
||||||
"Licensor" means any person or entity that distributes its Work.
|
|
||||||
|
|
||||||
"Software" means the original work of authorship made available under
|
|
||||||
this License.
|
|
||||||
|
|
||||||
"Work" means the Software and any additions to or derivative works of
|
|
||||||
the Software that are made available under this License.
|
|
||||||
|
|
||||||
"Nvidia Processors" means any central processing unit (CPU), graphics
|
|
||||||
processing unit (GPU), field-programmable gate array (FPGA),
|
|
||||||
application-specific integrated circuit (ASIC) or any combination
|
|
||||||
thereof designed, made, sold, or provided by Nvidia or its affiliates.
|
|
||||||
|
|
||||||
The terms "reproduce," "reproduction," "derivative works," and
|
|
||||||
"distribution" have the meaning as provided under U.S. copyright law;
|
|
||||||
provided, however, that for the purposes of this License, derivative
|
|
||||||
works shall not include works that remain separable from, or merely
|
|
||||||
link (or bind by name) to the interfaces of, the Work.
|
|
||||||
|
|
||||||
Works, including the Software, are "made available" under this License
|
|
||||||
by including in or with the Work either (a) a copyright notice
|
|
||||||
referencing the applicability of this License to the Work, or (b) a
|
|
||||||
copy of this License.
|
|
||||||
|
|
||||||
2. License Grants
|
|
||||||
|
|
||||||
2.1 Copyright Grant. Subject to the terms and conditions of this
|
|
||||||
License, each Licensor grants to you a perpetual, worldwide,
|
|
||||||
non-exclusive, royalty-free, copyright license to reproduce,
|
|
||||||
prepare derivative works of, publicly display, publicly perform,
|
|
||||||
sublicense and distribute its Work and any resulting derivative
|
|
||||||
works in any form.
|
|
||||||
|
|
||||||
3. Limitations
|
|
||||||
|
|
||||||
3.1 Redistribution. You may reproduce or distribute the Work only
|
|
||||||
if (a) you do so under this License, (b) you include a complete
|
|
||||||
copy of this License with your distribution, and (c) you retain
|
|
||||||
without modification any copyright, patent, trademark, or
|
|
||||||
attribution notices that are present in the Work.
|
|
||||||
|
|
||||||
3.2 Derivative Works. You may specify that additional or different
|
|
||||||
terms apply to the use, reproduction, and distribution of your
|
|
||||||
derivative works of the Work ("Your Terms") only if (a) Your Terms
|
|
||||||
provide that the use limitation in Section 3.3 applies to your
|
|
||||||
derivative works, and (b) you identify the specific derivative
|
|
||||||
works that are subject to Your Terms. Notwithstanding Your Terms,
|
|
||||||
this License (including the redistribution requirements in Section
|
|
||||||
3.1) will continue to apply to the Work itself.
|
|
||||||
|
|
||||||
3.3 Use Limitation. The Work and any derivative works thereof only
|
|
||||||
may be used or intended for use non-commercially. The Work or
|
|
||||||
derivative works thereof may be used or intended for use by Nvidia
|
|
||||||
or its affiliates commercially or non-commercially. As used herein,
|
|
||||||
"non-commercially" means for research or evaluation purposes only.
|
|
||||||
|
|
||||||
3.4 Patent Claims. If you bring or threaten to bring a patent claim
|
|
||||||
against any Licensor (including any claim, cross-claim or
|
|
||||||
counterclaim in a lawsuit) to enforce any patents that you allege
|
|
||||||
are infringed by any Work, then your rights under this License from
|
|
||||||
such Licensor (including the grants in Sections 2.1 and 2.2) will
|
|
||||||
terminate immediately.
|
|
||||||
|
|
||||||
3.5 Trademarks. This License does not grant any rights to use any
|
|
||||||
Licensor's or its affiliates' names, logos, or trademarks, except
|
|
||||||
as necessary to reproduce the notices described in this License.
|
|
||||||
|
|
||||||
3.6 Termination. If you violate any term of this License, then your
|
|
||||||
rights under this License (including the grants in Sections 2.1 and
|
|
||||||
2.2) will terminate immediately.
|
|
||||||
|
|
||||||
4. Disclaimer of Warranty.
|
|
||||||
|
|
||||||
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
|
||||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
|
|
||||||
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
|
|
||||||
THIS LICENSE.
|
|
||||||
|
|
||||||
5. Limitation of Liability.
|
|
||||||
|
|
||||||
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
|
|
||||||
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
|
|
||||||
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
|
|
||||||
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
|
|
||||||
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
|
|
||||||
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
|
|
||||||
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
|
|
||||||
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
|
|
||||||
THE POSSIBILITY OF SUCH DAMAGES.
|
|
||||||
|
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2019 Kim Seonghyeon
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the BSD 3-Clause license:
|
|
||||||
---------------------------------------------
|
|
||||||
1. torchvision
|
|
||||||
Copyright (c) Soumith Chintala 2016,
|
|
||||||
All rights reserved.
|
|
||||||
|
|
||||||
2. torch
|
|
||||||
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
|
||||||
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
|
||||||
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
|
||||||
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
|
||||||
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
|
||||||
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
|
||||||
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
|
||||||
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
|
||||||
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
|
||||||
|
|
||||||
|
|
||||||
Terms of the BSD 3-Clause License:
|
|
||||||
---------------------------------------------
|
|
||||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
|
||||||
|
|
||||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
|
||||||
|
|
||||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
|
||||||
|
|
||||||
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
|
||||||
---------------------------------------------
|
|
||||||
1. numpy
|
|
||||||
Copyright (c) 2005-2020, NumPy Developers.
|
|
||||||
All rights reserved.
|
|
||||||
|
|
||||||
A copy of BSD 3-Clause License is included in this file.
|
|
||||||
|
|
||||||
The NumPy repository and source distributions bundle several libraries that are
|
|
||||||
compatibly licensed. We list these here.
|
|
||||||
|
|
||||||
Name: Numpydoc
|
|
||||||
Files: doc/sphinxext/numpydoc/*
|
|
||||||
License: BSD-2-Clause
|
|
||||||
For details, see doc/sphinxext/LICENSE.txt
|
|
||||||
|
|
||||||
Name: scipy-sphinx-theme
|
|
||||||
Files: doc/scipy-sphinx-theme/*
|
|
||||||
License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
|
|
||||||
For details, see doc/scipy-sphinx-theme/LICENSE.txt
|
|
||||||
|
|
||||||
Name: lapack-lite
|
|
||||||
Files: numpy/linalg/lapack_lite/*
|
|
||||||
License: BSD-3-Clause
|
|
||||||
For details, see numpy/linalg/lapack_lite/LICENSE.txt
|
|
||||||
|
|
||||||
Name: tempita
|
|
||||||
Files: tools/npy_tempita/*
|
|
||||||
License: MIT
|
|
||||||
For details, see tools/npy_tempita/license.txt
|
|
||||||
|
|
||||||
Name: dragon4
|
|
||||||
Files: numpy/core/src/multiarray/dragon4.c
|
|
||||||
License: MIT
|
|
||||||
For license text, see numpy/core/src/multiarray/dragon4.c
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the MIT license:
|
|
||||||
---------------------------------------------
|
|
||||||
1. facexlib
|
|
||||||
Copyright (c) 2020 Xintao Wang
|
|
||||||
|
|
||||||
2. opencv-python
|
|
||||||
Copyright (c) Olli-Pekka Heinisuo
|
|
||||||
Please note that only files in cv2 package are used.
|
|
||||||
|
|
||||||
|
|
||||||
Terms of the MIT License:
|
|
||||||
---------------------------------------------
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
|
|
||||||
---------------------------------------------
|
|
||||||
1. tqdm
|
|
||||||
Copyright (c) 2013 noamraph
|
|
||||||
|
|
||||||
`tqdm` is a product of collaborative work.
|
|
||||||
Unless otherwise stated, all authors (see commit logs) retain copyright
|
|
||||||
for their respective work, and release the work under the MIT licence
|
|
||||||
(text below).
|
|
||||||
|
|
||||||
Exceptions or notable authors are listed below
|
|
||||||
in reverse chronological order:
|
|
||||||
|
|
||||||
* files: *
|
|
||||||
MPLv2.0 2015-2020 (c) Casper da Costa-Luis
|
|
||||||
[casperdcl](https://github.com/casperdcl).
|
|
||||||
* files: tqdm/_tqdm.py
|
|
||||||
MIT 2016 (c) [PR #96] on behalf of Google Inc.
|
|
||||||
* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
|
|
||||||
MIT 2013 (c) Noam Yorav-Raphael, original author.
|
|
||||||
|
|
||||||
[PR #96]: https://github.com/tqdm/tqdm/pull/96
|
|
||||||
|
|
||||||
|
|
||||||
Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
|
|
||||||
-----------------------------------------------
|
|
||||||
|
|
||||||
This Source Code Form is subject to the terms of the
|
|
||||||
Mozilla Public License, v. 2.0.
|
|
||||||
If a copy of the MPL was not distributed with this file,
|
|
||||||
You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
||||||
|
|
||||||
|
|
||||||
MIT License (MIT)
|
|
||||||
-----------------
|
|
||||||
|
|
||||||
Copyright (c) 2013 noamraph
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
||||||
this software and associated documentation files (the "Software"), to deal in
|
|
||||||
the Software without restriction, including without limitation the rights to
|
|
||||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
||||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
|
||||||
subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
||||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
||||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
||||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
||||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
@ -1,351 +0,0 @@
|
|||||||
Tencent is pleased to support the open source community by making GFPGAN available.
|
|
||||||
|
|
||||||
Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
|
|
||||||
|
|
||||||
GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
|
|
||||||
|
|
||||||
|
|
||||||
Terms of the Apache License Version 2.0:
|
|
||||||
---------------------------------------------
|
|
||||||
Apache License
|
|
||||||
|
|
||||||
Version 2.0, January 2004
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
|
|
||||||
|
|
||||||
“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
|
||||||
|
|
||||||
“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
|
||||||
|
|
||||||
“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
|
|
||||||
|
|
||||||
“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
|
||||||
|
|
||||||
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
|
||||||
|
|
||||||
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
|
||||||
|
|
||||||
If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Other dependencies and licenses:
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
|
|
||||||
---------------------------------------------
|
|
||||||
1. basicsr
|
|
||||||
Copyright 2018-2020 BasicSR Authors
|
|
||||||
|
|
||||||
|
|
||||||
This BasicSR project is released under the Apache 2.0 license.
|
|
||||||
|
|
||||||
A copy of Apache 2.0 is included in this file.
|
|
||||||
|
|
||||||
StyleGAN2
|
|
||||||
The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
|
|
||||||
The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
|
|
||||||
DFDNet
|
|
||||||
The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
|
|
||||||
|
|
||||||
Terms of the Nvidia License:
|
|
||||||
---------------------------------------------
|
|
||||||
|
|
||||||
1. Definitions
|
|
||||||
|
|
||||||
"Licensor" means any person or entity that distributes its Work.
|
|
||||||
|
|
||||||
"Software" means the original work of authorship made available under
|
|
||||||
this License.
|
|
||||||
|
|
||||||
"Work" means the Software and any additions to or derivative works of
|
|
||||||
the Software that are made available under this License.
|
|
||||||
|
|
||||||
"Nvidia Processors" means any central processing unit (CPU), graphics
|
|
||||||
processing unit (GPU), field-programmable gate array (FPGA),
|
|
||||||
application-specific integrated circuit (ASIC) or any combination
|
|
||||||
thereof designed, made, sold, or provided by Nvidia or its affiliates.
|
|
||||||
|
|
||||||
The terms "reproduce," "reproduction," "derivative works," and
|
|
||||||
"distribution" have the meaning as provided under U.S. copyright law;
|
|
||||||
provided, however, that for the purposes of this License, derivative
|
|
||||||
works shall not include works that remain separable from, or merely
|
|
||||||
link (or bind by name) to the interfaces of, the Work.
|
|
||||||
|
|
||||||
Works, including the Software, are "made available" under this License
|
|
||||||
by including in or with the Work either (a) a copyright notice
|
|
||||||
referencing the applicability of this License to the Work, or (b) a
|
|
||||||
copy of this License.
|
|
||||||
|
|
||||||
2. License Grants
|
|
||||||
|
|
||||||
2.1 Copyright Grant. Subject to the terms and conditions of this
|
|
||||||
License, each Licensor grants to you a perpetual, worldwide,
|
|
||||||
non-exclusive, royalty-free, copyright license to reproduce,
|
|
||||||
prepare derivative works of, publicly display, publicly perform,
|
|
||||||
sublicense and distribute its Work and any resulting derivative
|
|
||||||
works in any form.
|
|
||||||
|
|
||||||
3. Limitations
|
|
||||||
|
|
||||||
3.1 Redistribution. You may reproduce or distribute the Work only
|
|
||||||
if (a) you do so under this License, (b) you include a complete
|
|
||||||
copy of this License with your distribution, and (c) you retain
|
|
||||||
without modification any copyright, patent, trademark, or
|
|
||||||
attribution notices that are present in the Work.
|
|
||||||
|
|
||||||
3.2 Derivative Works. You may specify that additional or different
|
|
||||||
terms apply to the use, reproduction, and distribution of your
|
|
||||||
derivative works of the Work ("Your Terms") only if (a) Your Terms
|
|
||||||
provide that the use limitation in Section 3.3 applies to your
|
|
||||||
derivative works, and (b) you identify the specific derivative
|
|
||||||
works that are subject to Your Terms. Notwithstanding Your Terms,
|
|
||||||
this License (including the redistribution requirements in Section
|
|
||||||
3.1) will continue to apply to the Work itself.
|
|
||||||
|
|
||||||
3.3 Use Limitation. The Work and any derivative works thereof only
|
|
||||||
may be used or intended for use non-commercially. The Work or
|
|
||||||
derivative works thereof may be used or intended for use by Nvidia
|
|
||||||
or its affiliates commercially or non-commercially. As used herein,
|
|
||||||
"non-commercially" means for research or evaluation purposes only.
|
|
||||||
|
|
||||||
3.4 Patent Claims. If you bring or threaten to bring a patent claim
|
|
||||||
against any Licensor (including any claim, cross-claim or
|
|
||||||
counterclaim in a lawsuit) to enforce any patents that you allege
|
|
||||||
are infringed by any Work, then your rights under this License from
|
|
||||||
such Licensor (including the grants in Sections 2.1 and 2.2) will
|
|
||||||
terminate immediately.
|
|
||||||
|
|
||||||
3.5 Trademarks. This License does not grant any rights to use any
|
|
||||||
Licensor's or its affiliates' names, logos, or trademarks, except
|
|
||||||
as necessary to reproduce the notices described in this License.
|
|
||||||
|
|
||||||
3.6 Termination. If you violate any term of this License, then your
|
|
||||||
rights under this License (including the grants in Sections 2.1 and
|
|
||||||
2.2) will terminate immediately.
|
|
||||||
|
|
||||||
4. Disclaimer of Warranty.
|
|
||||||
|
|
||||||
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
|
||||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
|
|
||||||
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
|
|
||||||
THIS LICENSE.
|
|
||||||
|
|
||||||
5. Limitation of Liability.
|
|
||||||
|
|
||||||
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
|
|
||||||
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
|
|
||||||
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
|
|
||||||
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
|
|
||||||
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
|
|
||||||
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
|
|
||||||
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
|
|
||||||
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
|
|
||||||
THE POSSIBILITY OF SUCH DAMAGES.
|
|
||||||
|
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2019 Kim Seonghyeon
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the BSD 3-Clause license:
|
|
||||||
---------------------------------------------
|
|
||||||
1. torchvision
|
|
||||||
Copyright (c) Soumith Chintala 2016,
|
|
||||||
All rights reserved.
|
|
||||||
|
|
||||||
2. torch
|
|
||||||
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
|
||||||
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
|
||||||
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
|
||||||
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
|
||||||
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
|
||||||
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
|
||||||
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
|
||||||
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
|
||||||
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
|
||||||
|
|
||||||
|
|
||||||
Terms of the BSD 3-Clause License:
|
|
||||||
---------------------------------------------
|
|
||||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
|
||||||
|
|
||||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
|
||||||
|
|
||||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
|
||||||
|
|
||||||
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
|
||||||
---------------------------------------------
|
|
||||||
1. numpy
|
|
||||||
Copyright (c) 2005-2020, NumPy Developers.
|
|
||||||
All rights reserved.
|
|
||||||
|
|
||||||
A copy of BSD 3-Clause License is included in this file.
|
|
||||||
|
|
||||||
The NumPy repository and source distributions bundle several libraries that are
|
|
||||||
compatibly licensed. We list these here.
|
|
||||||
|
|
||||||
Name: Numpydoc
|
|
||||||
Files: doc/sphinxext/numpydoc/*
|
|
||||||
License: BSD-2-Clause
|
|
||||||
For details, see doc/sphinxext/LICENSE.txt
|
|
||||||
|
|
||||||
Name: scipy-sphinx-theme
|
|
||||||
Files: doc/scipy-sphinx-theme/*
|
|
||||||
License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
|
|
||||||
For details, see doc/scipy-sphinx-theme/LICENSE.txt
|
|
||||||
|
|
||||||
Name: lapack-lite
|
|
||||||
Files: numpy/linalg/lapack_lite/*
|
|
||||||
License: BSD-3-Clause
|
|
||||||
For details, see numpy/linalg/lapack_lite/LICENSE.txt
|
|
||||||
|
|
||||||
Name: tempita
|
|
||||||
Files: tools/npy_tempita/*
|
|
||||||
License: MIT
|
|
||||||
For details, see tools/npy_tempita/license.txt
|
|
||||||
|
|
||||||
Name: dragon4
|
|
||||||
Files: numpy/core/src/multiarray/dragon4.c
|
|
||||||
License: MIT
|
|
||||||
For license text, see numpy/core/src/multiarray/dragon4.c
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the MIT license:
|
|
||||||
---------------------------------------------
|
|
||||||
1. facexlib
|
|
||||||
Copyright (c) 2020 Xintao Wang
|
|
||||||
|
|
||||||
2. opencv-python
|
|
||||||
Copyright (c) Olli-Pekka Heinisuo
|
|
||||||
Please note that only files in cv2 package are used.
|
|
||||||
|
|
||||||
|
|
||||||
Terms of the MIT License:
|
|
||||||
---------------------------------------------
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
|
|
||||||
---------------------------------------------
|
|
||||||
1. tqdm
|
|
||||||
Copyright (c) 2013 noamraph
|
|
||||||
|
|
||||||
`tqdm` is a product of collaborative work.
|
|
||||||
Unless otherwise stated, all authors (see commit logs) retain copyright
|
|
||||||
for their respective work, and release the work under the MIT licence
|
|
||||||
(text below).
|
|
||||||
|
|
||||||
Exceptions or notable authors are listed below
|
|
||||||
in reverse chronological order:
|
|
||||||
|
|
||||||
* files: *
|
|
||||||
MPLv2.0 2015-2020 (c) Casper da Costa-Luis
|
|
||||||
[casperdcl](https://github.com/casperdcl).
|
|
||||||
* files: tqdm/_tqdm.py
|
|
||||||
MIT 2016 (c) [PR #96] on behalf of Google Inc.
|
|
||||||
* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
|
|
||||||
MIT 2013 (c) Noam Yorav-Raphael, original author.
|
|
||||||
|
|
||||||
[PR #96]: https://github.com/tqdm/tqdm/pull/96
|
|
||||||
|
|
||||||
|
|
||||||
Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
|
|
||||||
-----------------------------------------------
|
|
||||||
|
|
||||||
This Source Code Form is subject to the terms of the
|
|
||||||
Mozilla Public License, v. 2.0.
|
|
||||||
If a copy of the MPL was not distributed with this file,
|
|
||||||
You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
||||||
|
|
||||||
|
|
||||||
MIT License (MIT)
|
|
||||||
-----------------
|
|
||||||
|
|
||||||
Copyright (c) 2013 noamraph
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
||||||
this software and associated documentation files (the "Software"), to deal in
|
|
||||||
the Software without restriction, including without limitation the rights to
|
|
||||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
||||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
|
||||||
subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
||||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
||||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
||||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
||||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
@ -1,35 +0,0 @@
|
|||||||
S-Lab License 1.0
|
|
||||||
|
|
||||||
Copyright 2022 S-Lab
|
|
||||||
|
|
||||||
Redistribution and use for non-commercial purpose in source and
|
|
||||||
binary forms, with or without modification, are permitted provided
|
|
||||||
that the following conditions are met:
|
|
||||||
|
|
||||||
1. Redistributions of source code must retain the above copyright
|
|
||||||
notice, this list of conditions and the following disclaimer.
|
|
||||||
|
|
||||||
2. Redistributions in binary form must reproduce the above copyright
|
|
||||||
notice, this list of conditions and the following disclaimer in
|
|
||||||
the documentation and/or other materials provided with the
|
|
||||||
distribution.
|
|
||||||
|
|
||||||
3. Neither the name of the copyright holder nor the names of its
|
|
||||||
contributors may be used to endorse or promote products derived
|
|
||||||
from this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
||||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
||||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
||||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
||||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
||||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
||||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
||||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
||||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
||||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
|
|
||||||
In the event that redistribution and/or use for commercial purpose in
|
|
||||||
source or binary forms, with or without modification is required,
|
|
||||||
please contact the contributor(s) of the work.
|
|
||||||
@ -1,265 +0,0 @@
|
|||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def conv3x3(inplanes, outplanes, stride=1):
|
|
||||||
"""A simple wrapper for 3x3 convolution with padding.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inplanes (int): Channel number of inputs.
|
|
||||||
outplanes (int): Channel number of outputs.
|
|
||||||
stride (int): Stride in convolution. Default: 1.
|
|
||||||
"""
|
|
||||||
return nn.Conv2d(
|
|
||||||
inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BasicBlock(nn.Module):
|
|
||||||
"""Basic residual block used in the ResNetArcFace architecture.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inplanes (int): Channel number of inputs.
|
|
||||||
planes (int): Channel number of outputs.
|
|
||||||
stride (int): Stride in convolution. Default: 1.
|
|
||||||
downsample (nn.Module): The downsample module. Default: None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
expansion = 1 # output channel expansion ratio
|
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
|
||||||
super(BasicBlock, self).__init__()
|
|
||||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
|
||||||
self.bn1 = nn.BatchNorm2d(planes)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
self.conv2 = conv3x3(planes, planes)
|
|
||||||
self.bn2 = nn.BatchNorm2d(planes)
|
|
||||||
self.downsample = downsample
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
residual = x
|
|
||||||
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
residual = self.downsample(x)
|
|
||||||
|
|
||||||
out += residual
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class IRBlock(nn.Module):
|
|
||||||
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inplanes (int): Channel number of inputs.
|
|
||||||
planes (int): Channel number of outputs.
|
|
||||||
stride (int): Stride in convolution. Default: 1.
|
|
||||||
downsample (nn.Module): The downsample module. Default: None.
|
|
||||||
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
expansion = 1 # output channel expansion ratio
|
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
|
|
||||||
super(IRBlock, self).__init__()
|
|
||||||
self.bn0 = nn.BatchNorm2d(inplanes)
|
|
||||||
self.conv1 = conv3x3(inplanes, inplanes)
|
|
||||||
self.bn1 = nn.BatchNorm2d(inplanes)
|
|
||||||
self.prelu = nn.PReLU()
|
|
||||||
self.conv2 = conv3x3(inplanes, planes, stride)
|
|
||||||
self.bn2 = nn.BatchNorm2d(planes)
|
|
||||||
self.downsample = downsample
|
|
||||||
self.stride = stride
|
|
||||||
self.use_se = use_se
|
|
||||||
if self.use_se:
|
|
||||||
self.se = SEBlock(planes)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
residual = x
|
|
||||||
out = self.bn0(x)
|
|
||||||
out = self.conv1(out)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.prelu(out)
|
|
||||||
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
if self.use_se:
|
|
||||||
out = self.se(out)
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
residual = self.downsample(x)
|
|
||||||
|
|
||||||
out += residual
|
|
||||||
out = self.prelu(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Bottleneck(nn.Module):
|
|
||||||
"""Bottleneck block used in the ResNetArcFace architecture.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inplanes (int): Channel number of inputs.
|
|
||||||
planes (int): Channel number of outputs.
|
|
||||||
stride (int): Stride in convolution. Default: 1.
|
|
||||||
downsample (nn.Module): The downsample module. Default: None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
expansion = 4 # output channel expansion ratio
|
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
|
||||||
super(Bottleneck, self).__init__()
|
|
||||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
|
||||||
self.bn1 = nn.BatchNorm2d(planes)
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
|
|
||||||
)
|
|
||||||
self.bn2 = nn.BatchNorm2d(planes)
|
|
||||||
self.conv3 = nn.Conv2d(
|
|
||||||
planes, planes * self.expansion, kernel_size=1, bias=False
|
|
||||||
)
|
|
||||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
self.downsample = downsample
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
residual = x
|
|
||||||
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv3(out)
|
|
||||||
out = self.bn3(out)
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
residual = self.downsample(x)
|
|
||||||
|
|
||||||
out += residual
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class SEBlock(nn.Module):
|
|
||||||
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel (int): Channel number of inputs.
|
|
||||||
reduction (int): Channel reduction ration. Default: 16.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, channel, reduction=16):
|
|
||||||
super(SEBlock, self).__init__()
|
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(
|
|
||||||
1
|
|
||||||
) # pool to 1x1 without spatial information
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Linear(channel, channel // reduction),
|
|
||||||
nn.PReLU(),
|
|
||||||
nn.Linear(channel // reduction, channel),
|
|
||||||
nn.Sigmoid(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b, c, _, _ = x.size()
|
|
||||||
y = self.avg_pool(x).view(b, c)
|
|
||||||
y = self.fc(y).view(b, c, 1, 1)
|
|
||||||
return x * y
|
|
||||||
|
|
||||||
|
|
||||||
class ResNetArcFace(nn.Module):
|
|
||||||
"""ArcFace with ResNet architectures.
|
|
||||||
|
|
||||||
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
block (str): Block used in the ArcFace architecture.
|
|
||||||
layers (tuple(int)): Block numbers in each layer.
|
|
||||||
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, block, layers, use_se=True):
|
|
||||||
if block == "IRBlock":
|
|
||||||
block = IRBlock
|
|
||||||
self.inplanes = 64
|
|
||||||
self.use_se = use_se
|
|
||||||
super(ResNetArcFace, self).__init__()
|
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
|
|
||||||
self.bn1 = nn.BatchNorm2d(64)
|
|
||||||
self.prelu = nn.PReLU()
|
|
||||||
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
|
|
||||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
|
||||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
|
||||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
|
||||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
|
||||||
self.bn4 = nn.BatchNorm2d(512)
|
|
||||||
self.dropout = nn.Dropout()
|
|
||||||
self.fc5 = nn.Linear(512 * 8 * 8, 512)
|
|
||||||
self.bn5 = nn.BatchNorm1d(512)
|
|
||||||
|
|
||||||
# initialization
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
nn.init.xavier_normal_(m.weight)
|
|
||||||
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
|
|
||||||
nn.init.constant_(m.weight, 1)
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
nn.init.xavier_normal_(m.weight)
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
def _make_layer(self, block, planes, num_blocks, stride=1):
|
|
||||||
downsample = None
|
|
||||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
|
||||||
downsample = nn.Sequential(
|
|
||||||
nn.Conv2d(
|
|
||||||
self.inplanes,
|
|
||||||
planes * block.expansion,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=stride,
|
|
||||||
bias=False,
|
|
||||||
),
|
|
||||||
nn.BatchNorm2d(planes * block.expansion),
|
|
||||||
)
|
|
||||||
layers = []
|
|
||||||
layers.append(
|
|
||||||
block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
|
|
||||||
)
|
|
||||||
self.inplanes = planes
|
|
||||||
for _ in range(1, num_blocks):
|
|
||||||
layers.append(block(self.inplanes, planes, use_se=self.use_se))
|
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.bn1(x)
|
|
||||||
x = self.prelu(x)
|
|
||||||
x = self.maxpool(x)
|
|
||||||
|
|
||||||
x = self.layer1(x)
|
|
||||||
x = self.layer2(x)
|
|
||||||
x = self.layer3(x)
|
|
||||||
x = self.layer4(x)
|
|
||||||
x = self.bn4(x)
|
|
||||||
x = self.dropout(x)
|
|
||||||
x = x.view(x.size(0), -1)
|
|
||||||
x = self.fc5(x)
|
|
||||||
x = self.bn5(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
@ -1,790 +0,0 @@
|
|||||||
"""
|
|
||||||
Modified from https://github.com/sczhou/CodeFormer
|
|
||||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
|
||||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
|
||||||
This verison of the arch specifically was gathered from an old version of GFPGAN. If this is a problem, please contact me.
|
|
||||||
"""
|
|
||||||
import math
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import logging as logger
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class VectorQuantizer(nn.Module):
|
|
||||||
def __init__(self, codebook_size, emb_dim, beta):
|
|
||||||
super(VectorQuantizer, self).__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
|
||||||
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
|
||||||
self.embedding.weight.data.uniform_(
|
|
||||||
-1.0 / self.codebook_size, 1.0 / self.codebook_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
# reshape z -> (batch, height, width, channel) and flatten
|
|
||||||
z = z.permute(0, 2, 3, 1).contiguous()
|
|
||||||
z_flattened = z.view(-1, self.emb_dim)
|
|
||||||
|
|
||||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
|
||||||
d = (
|
|
||||||
(z_flattened**2).sum(dim=1, keepdim=True)
|
|
||||||
+ (self.embedding.weight**2).sum(1)
|
|
||||||
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
|
||||||
)
|
|
||||||
|
|
||||||
mean_distance = torch.mean(d)
|
|
||||||
# find closest encodings
|
|
||||||
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
|
||||||
min_encoding_scores, min_encoding_indices = torch.topk(
|
|
||||||
d, 1, dim=1, largest=False
|
|
||||||
)
|
|
||||||
# [0-1], higher score, higher confidence
|
|
||||||
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
|
|
||||||
|
|
||||||
min_encodings = torch.zeros(
|
|
||||||
min_encoding_indices.shape[0], self.codebook_size
|
|
||||||
).to(z)
|
|
||||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
|
||||||
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
|
||||||
# compute loss for embedding
|
|
||||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
|
||||||
(z_q - z.detach()) ** 2
|
|
||||||
)
|
|
||||||
# preserve gradients
|
|
||||||
z_q = z + (z_q - z).detach()
|
|
||||||
|
|
||||||
# perplexity
|
|
||||||
e_mean = torch.mean(min_encodings, dim=0)
|
|
||||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
|
||||||
# reshape back to match original input shape
|
|
||||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return (
|
|
||||||
z_q,
|
|
||||||
loss,
|
|
||||||
{
|
|
||||||
"perplexity": perplexity,
|
|
||||||
"min_encodings": min_encodings,
|
|
||||||
"min_encoding_indices": min_encoding_indices,
|
|
||||||
"min_encoding_scores": min_encoding_scores,
|
|
||||||
"mean_distance": mean_distance,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_codebook_feat(self, indices, shape):
|
|
||||||
# input indices: batch*token_num -> (batch*token_num)*1
|
|
||||||
# shape: batch, height, width, channel
|
|
||||||
indices = indices.view(-1, 1)
|
|
||||||
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
|
||||||
min_encodings.scatter_(1, indices, 1)
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
|
||||||
|
|
||||||
if shape is not None: # reshape back to match original input shape
|
|
||||||
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return z_q
|
|
||||||
|
|
||||||
|
|
||||||
class GumbelQuantizer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
codebook_size,
|
|
||||||
emb_dim,
|
|
||||||
num_hiddens,
|
|
||||||
straight_through=False,
|
|
||||||
kl_weight=5e-4,
|
|
||||||
temp_init=1.0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.straight_through = straight_through
|
|
||||||
self.temperature = temp_init
|
|
||||||
self.kl_weight = kl_weight
|
|
||||||
self.proj = nn.Conv2d(
|
|
||||||
num_hiddens, codebook_size, 1
|
|
||||||
) # projects last encoder layer to quantized logits
|
|
||||||
self.embed = nn.Embedding(codebook_size, emb_dim)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
hard = self.straight_through if self.training else True
|
|
||||||
|
|
||||||
logits = self.proj(z)
|
|
||||||
|
|
||||||
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
|
||||||
|
|
||||||
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
|
||||||
|
|
||||||
# + kl divergence to the prior loss
|
|
||||||
qy = F.softmax(logits, dim=1)
|
|
||||||
diff = (
|
|
||||||
self.kl_weight
|
|
||||||
* torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
|
||||||
)
|
|
||||||
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
|
||||||
|
|
||||||
return z_q, diff, {"min_encoding_indices": min_encoding_indices}
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
pad = (0, 1, 0, 1)
|
|
||||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
|
||||||
x = self.conv(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
|
||||||
x = self.conv(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
self.norm = normalize(in_channels)
|
|
||||||
self.q = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.k = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.v = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.proj_out = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q = self.q(h_)
|
|
||||||
k = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b, c, h, w = q.shape
|
|
||||||
q = q.reshape(b, c, h * w)
|
|
||||||
q = q.permute(0, 2, 1)
|
|
||||||
k = k.reshape(b, c, h * w)
|
|
||||||
w_ = torch.bmm(q, k)
|
|
||||||
w_ = w_ * (int(c) ** (-0.5))
|
|
||||||
w_ = F.softmax(w_, dim=2)
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v = v.reshape(b, c, h * w)
|
|
||||||
w_ = w_.permute(0, 2, 1)
|
|
||||||
h_ = torch.bmm(v, w_)
|
|
||||||
h_ = h_.reshape(b, c, h, w)
|
|
||||||
|
|
||||||
h_ = self.proj_out(h_)
|
|
||||||
|
|
||||||
return x + h_
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
nf,
|
|
||||||
out_channels,
|
|
||||||
ch_mult,
|
|
||||||
num_res_blocks,
|
|
||||||
resolution,
|
|
||||||
attn_resolutions,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
|
|
||||||
curr_res = self.resolution
|
|
||||||
in_ch_mult = (1,) + tuple(ch_mult)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial convultion
|
|
||||||
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
|
||||||
|
|
||||||
# residual and downsampling blocks, with attention on smaller res (16x16)
|
|
||||||
for i in range(self.num_resolutions):
|
|
||||||
block_in_ch = nf * in_ch_mult[i]
|
|
||||||
block_out_ch = nf * ch_mult[i]
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
if curr_res in attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != self.num_resolutions - 1:
|
|
||||||
blocks.append(Downsample(block_in_ch))
|
|
||||||
curr_res = curr_res // 2
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
|
|
||||||
blocks.append(AttnBlock(block_in_ch)) # type: ignore
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
|
|
||||||
|
|
||||||
# normalise and convert to latent size
|
|
||||||
blocks.append(normalize(block_in_ch)) # type: ignore
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(block_in_ch, out_channels, kernel_size=3, stride=1, padding=1) # type: ignore
|
|
||||||
)
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
|
||||||
def __init__(self, nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.num_resolutions = len(self.ch_mult)
|
|
||||||
self.num_res_blocks = res_blocks
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
self.in_channels = emb_dim
|
|
||||||
self.out_channels = 3
|
|
||||||
block_in_ch = self.nf * self.ch_mult[-1]
|
|
||||||
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial conv
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
|
|
||||||
for i in reversed(range(self.num_resolutions)):
|
|
||||||
block_out_ch = self.nf * self.ch_mult[i]
|
|
||||||
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
|
|
||||||
if curr_res in self.attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != 0:
|
|
||||||
blocks.append(Upsample(block_in_ch))
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
|
|
||||||
blocks.append(normalize(block_in_ch))
|
|
||||||
blocks.append(
|
|
||||||
nn.Conv2d(
|
|
||||||
block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class VQAutoEncoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
img_size,
|
|
||||||
nf,
|
|
||||||
ch_mult,
|
|
||||||
quantizer="nearest",
|
|
||||||
res_blocks=2,
|
|
||||||
attn_resolutions=[16],
|
|
||||||
codebook_size=1024,
|
|
||||||
emb_dim=256,
|
|
||||||
beta=0.25,
|
|
||||||
gumbel_straight_through=False,
|
|
||||||
gumbel_kl_weight=1e-8,
|
|
||||||
model_path=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = 3
|
|
||||||
self.nf = nf
|
|
||||||
self.n_blocks = res_blocks
|
|
||||||
self.codebook_size = codebook_size
|
|
||||||
self.embed_dim = emb_dim
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
self.quantizer_type = quantizer
|
|
||||||
self.encoder = Encoder(
|
|
||||||
self.in_channels,
|
|
||||||
self.nf,
|
|
||||||
self.embed_dim,
|
|
||||||
self.ch_mult,
|
|
||||||
self.n_blocks,
|
|
||||||
self.resolution,
|
|
||||||
self.attn_resolutions,
|
|
||||||
)
|
|
||||||
if self.quantizer_type == "nearest":
|
|
||||||
self.beta = beta # 0.25
|
|
||||||
self.quantize = VectorQuantizer(
|
|
||||||
self.codebook_size, self.embed_dim, self.beta
|
|
||||||
)
|
|
||||||
elif self.quantizer_type == "gumbel":
|
|
||||||
self.gumbel_num_hiddens = emb_dim
|
|
||||||
self.straight_through = gumbel_straight_through
|
|
||||||
self.kl_weight = gumbel_kl_weight
|
|
||||||
self.quantize = GumbelQuantizer(
|
|
||||||
self.codebook_size,
|
|
||||||
self.embed_dim,
|
|
||||||
self.gumbel_num_hiddens,
|
|
||||||
self.straight_through,
|
|
||||||
self.kl_weight,
|
|
||||||
)
|
|
||||||
self.generator = Generator(
|
|
||||||
nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_path is not None:
|
|
||||||
chkpt = torch.load(model_path, map_location="cpu")
|
|
||||||
if "params_ema" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params_ema"]
|
|
||||||
)
|
|
||||||
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
|
|
||||||
elif "params" in chkpt:
|
|
||||||
self.load_state_dict(
|
|
||||||
torch.load(model_path, map_location="cpu")["params"]
|
|
||||||
)
|
|
||||||
logger.info(f"vqgan is loaded from: {model_path} [params]")
|
|
||||||
else:
|
|
||||||
raise ValueError("Wrong params!")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.encoder(x)
|
|
||||||
quant, codebook_loss, quant_stats = self.quantize(x)
|
|
||||||
x = self.generator(quant)
|
|
||||||
return x, codebook_loss, quant_stats
|
|
||||||
|
|
||||||
|
|
||||||
def calc_mean_std(feat, eps=1e-5):
|
|
||||||
"""Calculate mean and std for adaptive_instance_normalization.
|
|
||||||
Args:
|
|
||||||
feat (Tensor): 4D tensor.
|
|
||||||
eps (float): A small value added to the variance to avoid
|
|
||||||
divide-by-zero. Default: 1e-5.
|
|
||||||
"""
|
|
||||||
size = feat.size()
|
|
||||||
assert len(size) == 4, "The input feature should be 4D tensor."
|
|
||||||
b, c = size[:2]
|
|
||||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
|
||||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
|
||||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
|
||||||
return feat_mean, feat_std
|
|
||||||
|
|
||||||
|
|
||||||
def adaptive_instance_normalization(content_feat, style_feat):
|
|
||||||
"""Adaptive instance normalization.
|
|
||||||
Adjust the reference features to have the similar color and illuminations
|
|
||||||
as those in the degradate features.
|
|
||||||
Args:
|
|
||||||
content_feat (Tensor): The reference feature.
|
|
||||||
style_feat (Tensor): The degradate features.
|
|
||||||
"""
|
|
||||||
size = content_feat.size()
|
|
||||||
style_mean, style_std = calc_mean_std(style_feat)
|
|
||||||
content_mean, content_std = calc_mean_std(content_feat)
|
|
||||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
|
|
||||||
size
|
|
||||||
)
|
|
||||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingSine(nn.Module):
|
|
||||||
"""
|
|
||||||
This is a more standard version of the position embedding, very similar to the one
|
|
||||||
used by the Attention is all you need paper, generalized to work on images.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.num_pos_feats = num_pos_feats
|
|
||||||
self.temperature = temperature
|
|
||||||
self.normalize = normalize
|
|
||||||
if scale is not None and normalize is False:
|
|
||||||
raise ValueError("normalize should be True if scale is passed")
|
|
||||||
if scale is None:
|
|
||||||
scale = 2 * math.pi
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
|
||||||
if mask is None:
|
|
||||||
mask = torch.zeros(
|
|
||||||
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
|
|
||||||
)
|
|
||||||
not_mask = ~mask # pylint: disable=invalid-unary-operand-type
|
|
||||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
|
||||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
|
||||||
if self.normalize:
|
|
||||||
eps = 1e-6
|
|
||||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t
|
|
||||||
pos_y = y_embed[:, :, :, None] / dim_t
|
|
||||||
pos_x = torch.stack(
|
|
||||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos_y = torch.stack(
|
|
||||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
|
||||||
return pos
|
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation):
|
|
||||||
"""Return an activation function given a string"""
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
if activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
if activation == "glu":
|
|
||||||
return F.glu
|
|
||||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerSALayer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model - MLP
|
|
||||||
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(embed_dim)
|
|
||||||
self.norm2 = nn.LayerNorm(embed_dim)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
|
|
||||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
tgt,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None,
|
|
||||||
):
|
|
||||||
# self attention
|
|
||||||
tgt2 = self.norm1(tgt)
|
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
|
||||||
tgt2 = self.self_attn(
|
|
||||||
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
|
||||||
)[0]
|
|
||||||
tgt = tgt + self.dropout1(tgt2)
|
|
||||||
|
|
||||||
# ffn
|
|
||||||
tgt2 = self.norm2(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
|
||||||
tgt = tgt + self.dropout2(tgt2)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(in_channels):
|
|
||||||
return torch.nn.GroupNorm(
|
|
||||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script # type: ignore
|
|
||||||
def swish(x):
|
|
||||||
return x * torch.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels=None):
|
|
||||||
super(ResBlock, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = in_channels if out_channels is None else out_channels
|
|
||||||
self.norm1 = normalize(in_channels)
|
|
||||||
self.conv1 = nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1 # type: ignore
|
|
||||||
)
|
|
||||||
self.norm2 = normalize(out_channels)
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1 # type: ignore
|
|
||||||
)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
self.conv_out = nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0 # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x_in):
|
|
||||||
x = x_in
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv2(x)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
x_in = self.conv_out(x_in)
|
|
||||||
|
|
||||||
return x + x_in
|
|
||||||
|
|
||||||
|
|
||||||
class Fuse_sft_block(nn.Module):
|
|
||||||
def __init__(self, in_ch, out_ch):
|
|
||||||
super().__init__()
|
|
||||||
self.encode_enc = ResBlock(2 * in_ch, out_ch)
|
|
||||||
|
|
||||||
self.scale = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.shift = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, enc_feat, dec_feat, w=1):
|
|
||||||
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
|
||||||
scale = self.scale(enc_feat)
|
|
||||||
shift = self.shift(enc_feat)
|
|
||||||
residual = w * (dec_feat * scale + shift)
|
|
||||||
out = dec_feat + residual
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class CodeFormer(VQAutoEncoder):
|
|
||||||
def __init__(self, state_dict):
|
|
||||||
dim_embd = 512
|
|
||||||
n_head = 8
|
|
||||||
n_layers = 9
|
|
||||||
codebook_size = 1024
|
|
||||||
latent_size = 256
|
|
||||||
connect_list = ["32", "64", "128", "256"]
|
|
||||||
fix_modules = ["quantize", "generator"]
|
|
||||||
|
|
||||||
# This is just a guess as I only have one model to look at
|
|
||||||
position_emb = state_dict["position_emb"]
|
|
||||||
dim_embd = position_emb.shape[1]
|
|
||||||
latent_size = position_emb.shape[0]
|
|
||||||
|
|
||||||
try:
|
|
||||||
n_layers = len(
|
|
||||||
set([x.split(".")[1] for x in state_dict.keys() if "ft_layers" in x])
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
codebook_size = state_dict["quantize.embedding.weight"].shape[0]
|
|
||||||
|
|
||||||
# This is also just another guess
|
|
||||||
n_head_exp = (
|
|
||||||
state_dict["ft_layers.0.self_attn.in_proj_weight"].shape[0] // dim_embd
|
|
||||||
)
|
|
||||||
n_head = 2**n_head_exp
|
|
||||||
|
|
||||||
in_nc = state_dict["encoder.blocks.0.weight"].shape[1]
|
|
||||||
|
|
||||||
self.model_arch = "CodeFormer"
|
|
||||||
self.sub_type = "Face SR"
|
|
||||||
self.scale = 8
|
|
||||||
self.in_nc = in_nc
|
|
||||||
self.out_nc = in_nc
|
|
||||||
|
|
||||||
self.state = state_dict
|
|
||||||
|
|
||||||
self.supports_fp16 = False
|
|
||||||
self.supports_bf16 = True
|
|
||||||
self.min_size_restriction = 16
|
|
||||||
|
|
||||||
super(CodeFormer, self).__init__(
|
|
||||||
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if fix_modules is not None:
|
|
||||||
for module in fix_modules:
|
|
||||||
for param in getattr(self, module).parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
self.connect_list = connect_list
|
|
||||||
self.n_layers = n_layers
|
|
||||||
self.dim_embd = dim_embd
|
|
||||||
self.dim_mlp = dim_embd * 2
|
|
||||||
|
|
||||||
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) # type: ignore
|
|
||||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
|
||||||
|
|
||||||
# transformer
|
|
||||||
self.ft_layers = nn.Sequential(
|
|
||||||
*[
|
|
||||||
TransformerSALayer(
|
|
||||||
embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
|
|
||||||
)
|
|
||||||
for _ in range(self.n_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# logits_predict head
|
|
||||||
self.idx_pred_layer = nn.Sequential(
|
|
||||||
nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.channels = {
|
|
||||||
"16": 512,
|
|
||||||
"32": 256,
|
|
||||||
"64": 256,
|
|
||||||
"128": 128,
|
|
||||||
"256": 128,
|
|
||||||
"512": 64,
|
|
||||||
}
|
|
||||||
|
|
||||||
# after second residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_encoder_block = {
|
|
||||||
"512": 2,
|
|
||||||
"256": 5,
|
|
||||||
"128": 8,
|
|
||||||
"64": 11,
|
|
||||||
"32": 14,
|
|
||||||
"16": 18,
|
|
||||||
}
|
|
||||||
# after first residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_generator_block = {
|
|
||||||
"16": 6,
|
|
||||||
"32": 9,
|
|
||||||
"64": 12,
|
|
||||||
"128": 15,
|
|
||||||
"256": 18,
|
|
||||||
"512": 21,
|
|
||||||
}
|
|
||||||
|
|
||||||
# fuse_convs_dict
|
|
||||||
self.fuse_convs_dict = nn.ModuleDict()
|
|
||||||
for f_size in self.connect_list:
|
|
||||||
in_ch = self.channels[f_size]
|
|
||||||
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
|
||||||
|
|
||||||
self.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
|
||||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
elif isinstance(module, nn.LayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
|
|
||||||
def forward(self, x, weight=0.5, **kwargs):
|
|
||||||
detach_16 = True
|
|
||||||
code_only = False
|
|
||||||
adain = True
|
|
||||||
# ################### Encoder #####################
|
|
||||||
enc_feat_dict = {}
|
|
||||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
|
||||||
for i, block in enumerate(self.encoder.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in out_list:
|
|
||||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
|
||||||
|
|
||||||
lq_feat = x
|
|
||||||
# ################# Transformer ###################
|
|
||||||
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
|
||||||
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
|
|
||||||
# BCHW -> BC(HW) -> (HW)BC
|
|
||||||
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
|
|
||||||
query_emb = feat_emb
|
|
||||||
# Transformer encoder
|
|
||||||
for layer in self.ft_layers:
|
|
||||||
query_emb = layer(query_emb, query_pos=pos_emb)
|
|
||||||
|
|
||||||
# output logits
|
|
||||||
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
|
||||||
logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
|
|
||||||
|
|
||||||
if code_only: # for training stage II
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
return logits, lq_feat
|
|
||||||
|
|
||||||
# ################# Quantization ###################
|
|
||||||
# if self.training:
|
|
||||||
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
|
||||||
# # b(hw)c -> bc(hw) -> bchw
|
|
||||||
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
|
||||||
# ------------
|
|
||||||
soft_one_hot = F.softmax(logits, dim=2)
|
|
||||||
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
|
||||||
quant_feat = self.quantize.get_codebook_feat(
|
|
||||||
top_idx, shape=[x.shape[0], 16, 16, 256] # type: ignore
|
|
||||||
)
|
|
||||||
# preserve gradients
|
|
||||||
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
|
||||||
|
|
||||||
if detach_16:
|
|
||||||
quant_feat = quant_feat.detach() # for training stage III
|
|
||||||
if adain:
|
|
||||||
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
|
||||||
|
|
||||||
# ################## Generator ####################
|
|
||||||
x = quant_feat
|
|
||||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
|
||||||
|
|
||||||
for i, block in enumerate(self.generator.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in fuse_list: # fuse after i-th block
|
|
||||||
f_size = str(x.shape[-1])
|
|
||||||
if weight > 0:
|
|
||||||
x = self.fuse_convs_dict[f_size](
|
|
||||||
enc_feat_dict[f_size].detach(), x, weight
|
|
||||||
)
|
|
||||||
out = x
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
# return out, logits, lq_feat
|
|
||||||
return out, logits
|
|
||||||
@ -1,81 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# type: ignore
|
|
||||||
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.autograd import Function
|
|
||||||
|
|
||||||
fused_act_ext = None
|
|
||||||
|
|
||||||
|
|
||||||
class FusedLeakyReLUFunctionBackward(Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, grad_output, out, negative_slope, scale):
|
|
||||||
ctx.save_for_backward(out)
|
|
||||||
ctx.negative_slope = negative_slope
|
|
||||||
ctx.scale = scale
|
|
||||||
|
|
||||||
empty = grad_output.new_empty(0)
|
|
||||||
|
|
||||||
grad_input = fused_act_ext.fused_bias_act(
|
|
||||||
grad_output, empty, out, 3, 1, negative_slope, scale
|
|
||||||
)
|
|
||||||
|
|
||||||
dim = [0]
|
|
||||||
|
|
||||||
if grad_input.ndim > 2:
|
|
||||||
dim += list(range(2, grad_input.ndim))
|
|
||||||
|
|
||||||
grad_bias = grad_input.sum(dim).detach()
|
|
||||||
|
|
||||||
return grad_input, grad_bias
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, gradgrad_input, gradgrad_bias):
|
|
||||||
(out,) = ctx.saved_tensors
|
|
||||||
gradgrad_out = fused_act_ext.fused_bias_act(
|
|
||||||
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
|
|
||||||
)
|
|
||||||
|
|
||||||
return gradgrad_out, None, None, None
|
|
||||||
|
|
||||||
|
|
||||||
class FusedLeakyReLUFunction(Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input, bias, negative_slope, scale):
|
|
||||||
empty = input.new_empty(0)
|
|
||||||
out = fused_act_ext.fused_bias_act(
|
|
||||||
input, bias, empty, 3, 0, negative_slope, scale
|
|
||||||
)
|
|
||||||
ctx.save_for_backward(out)
|
|
||||||
ctx.negative_slope = negative_slope
|
|
||||||
ctx.scale = scale
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
(out,) = ctx.saved_tensors
|
|
||||||
|
|
||||||
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
|
||||||
grad_output, out, ctx.negative_slope, ctx.scale
|
|
||||||
)
|
|
||||||
|
|
||||||
return grad_input, grad_bias, None, None
|
|
||||||
|
|
||||||
|
|
||||||
class FusedLeakyReLU(nn.Module):
|
|
||||||
def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.bias = nn.Parameter(torch.zeros(channel))
|
|
||||||
self.negative_slope = negative_slope
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
|
||||||
|
|
||||||
|
|
||||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
|
|
||||||
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
|
||||||
@ -1,389 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# type: ignore
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from .gfpganv1_arch import ResUpBlock
|
|
||||||
from .stylegan2_bilinear_arch import (
|
|
||||||
ConvLayer,
|
|
||||||
EqualConv2d,
|
|
||||||
EqualLinear,
|
|
||||||
ResBlock,
|
|
||||||
ScaledLeakyReLU,
|
|
||||||
StyleGAN2GeneratorBilinear,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
|
|
||||||
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
|
||||||
It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
|
|
||||||
deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
|
|
||||||
Args:
|
|
||||||
out_size (int): The spatial size of outputs.
|
|
||||||
num_style_feat (int): Channel number of style features. Default: 512.
|
|
||||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
|
||||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
|
||||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
|
||||||
narrow (float): The narrow ratio for channels. Default: 1.
|
|
||||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
out_size,
|
|
||||||
num_style_feat=512,
|
|
||||||
num_mlp=8,
|
|
||||||
channel_multiplier=2,
|
|
||||||
lr_mlp=0.01,
|
|
||||||
narrow=1,
|
|
||||||
sft_half=False,
|
|
||||||
):
|
|
||||||
super(StyleGAN2GeneratorBilinearSFT, self).__init__(
|
|
||||||
out_size,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
num_mlp=num_mlp,
|
|
||||||
channel_multiplier=channel_multiplier,
|
|
||||||
lr_mlp=lr_mlp,
|
|
||||||
narrow=narrow,
|
|
||||||
)
|
|
||||||
self.sft_half = sft_half
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
styles,
|
|
||||||
conditions,
|
|
||||||
input_is_latent=False,
|
|
||||||
noise=None,
|
|
||||||
randomize_noise=True,
|
|
||||||
truncation=1,
|
|
||||||
truncation_latent=None,
|
|
||||||
inject_index=None,
|
|
||||||
return_latents=False,
|
|
||||||
):
|
|
||||||
"""Forward function for StyleGAN2GeneratorBilinearSFT.
|
|
||||||
Args:
|
|
||||||
styles (list[Tensor]): Sample codes of styles.
|
|
||||||
conditions (list[Tensor]): SFT conditions to generators.
|
|
||||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
|
||||||
noise (Tensor | None): Input noise or None. Default: None.
|
|
||||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
|
||||||
truncation (float): The truncation ratio. Default: 1.
|
|
||||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
|
||||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
|
||||||
return_latents (bool): Whether to return style latents. Default: False.
|
|
||||||
"""
|
|
||||||
# style codes -> latents with Style MLP layer
|
|
||||||
if not input_is_latent:
|
|
||||||
styles = [self.style_mlp(s) for s in styles]
|
|
||||||
# noises
|
|
||||||
if noise is None:
|
|
||||||
if randomize_noise:
|
|
||||||
noise = [None] * self.num_layers # for each style conv layer
|
|
||||||
else: # use the stored noise
|
|
||||||
noise = [
|
|
||||||
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
|
|
||||||
]
|
|
||||||
# style truncation
|
|
||||||
if truncation < 1:
|
|
||||||
style_truncation = []
|
|
||||||
for style in styles:
|
|
||||||
style_truncation.append(
|
|
||||||
truncation_latent + truncation * (style - truncation_latent)
|
|
||||||
)
|
|
||||||
styles = style_truncation
|
|
||||||
# get style latents with injection
|
|
||||||
if len(styles) == 1:
|
|
||||||
inject_index = self.num_latent
|
|
||||||
|
|
||||||
if styles[0].ndim < 3:
|
|
||||||
# repeat latent code for all the layers
|
|
||||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
else: # used for encoder with different latent code for each layer
|
|
||||||
latent = styles[0]
|
|
||||||
elif len(styles) == 2: # mixing noises
|
|
||||||
if inject_index is None:
|
|
||||||
inject_index = random.randint(1, self.num_latent - 1)
|
|
||||||
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
latent2 = (
|
|
||||||
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
|
||||||
)
|
|
||||||
latent = torch.cat([latent1, latent2], 1)
|
|
||||||
|
|
||||||
# main generation
|
|
||||||
out = self.constant_input(latent.shape[0])
|
|
||||||
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
|
||||||
skip = self.to_rgb1(out, latent[:, 1])
|
|
||||||
|
|
||||||
i = 1
|
|
||||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
|
||||||
self.style_convs[::2],
|
|
||||||
self.style_convs[1::2],
|
|
||||||
noise[1::2],
|
|
||||||
noise[2::2],
|
|
||||||
self.to_rgbs,
|
|
||||||
):
|
|
||||||
out = conv1(out, latent[:, i], noise=noise1)
|
|
||||||
|
|
||||||
# the conditions may have fewer levels
|
|
||||||
if i < len(conditions):
|
|
||||||
# SFT part to combine the conditions
|
|
||||||
if self.sft_half: # only apply SFT to half of the channels
|
|
||||||
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
|
||||||
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
|
||||||
out = torch.cat([out_same, out_sft], dim=1)
|
|
||||||
else: # apply SFT to all the channels
|
|
||||||
out = out * conditions[i - 1] + conditions[i]
|
|
||||||
|
|
||||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
|
||||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
|
||||||
i += 2
|
|
||||||
|
|
||||||
image = skip
|
|
||||||
|
|
||||||
if return_latents:
|
|
||||||
return image, latent
|
|
||||||
else:
|
|
||||||
return image, None
|
|
||||||
|
|
||||||
|
|
||||||
class GFPGANBilinear(nn.Module):
|
|
||||||
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
|
||||||
It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
|
|
||||||
deployment. It can be easily converted to the clean version: GFPGANv1Clean.
|
|
||||||
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
|
||||||
Args:
|
|
||||||
out_size (int): The spatial size of outputs.
|
|
||||||
num_style_feat (int): Channel number of style features. Default: 512.
|
|
||||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
|
||||||
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
|
||||||
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
|
||||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
|
||||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
|
||||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
|
||||||
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
|
||||||
narrow (float): The narrow ratio for channels. Default: 1.
|
|
||||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
out_size,
|
|
||||||
num_style_feat=512,
|
|
||||||
channel_multiplier=1,
|
|
||||||
decoder_load_path=None,
|
|
||||||
fix_decoder=True,
|
|
||||||
# for stylegan decoder
|
|
||||||
num_mlp=8,
|
|
||||||
lr_mlp=0.01,
|
|
||||||
input_is_latent=False,
|
|
||||||
different_w=False,
|
|
||||||
narrow=1,
|
|
||||||
sft_half=False,
|
|
||||||
):
|
|
||||||
super(GFPGANBilinear, self).__init__()
|
|
||||||
self.input_is_latent = input_is_latent
|
|
||||||
self.different_w = different_w
|
|
||||||
self.num_style_feat = num_style_feat
|
|
||||||
self.min_size_restriction = 512
|
|
||||||
|
|
||||||
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
|
||||||
channels = {
|
|
||||||
"4": int(512 * unet_narrow),
|
|
||||||
"8": int(512 * unet_narrow),
|
|
||||||
"16": int(512 * unet_narrow),
|
|
||||||
"32": int(512 * unet_narrow),
|
|
||||||
"64": int(256 * channel_multiplier * unet_narrow),
|
|
||||||
"128": int(128 * channel_multiplier * unet_narrow),
|
|
||||||
"256": int(64 * channel_multiplier * unet_narrow),
|
|
||||||
"512": int(32 * channel_multiplier * unet_narrow),
|
|
||||||
"1024": int(16 * channel_multiplier * unet_narrow),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.log_size = int(math.log(out_size, 2))
|
|
||||||
first_out_size = 2 ** (int(math.log(out_size, 2)))
|
|
||||||
|
|
||||||
self.conv_body_first = ConvLayer(
|
|
||||||
3, channels[f"{first_out_size}"], 1, bias=True, activate=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# downsample
|
|
||||||
in_channels = channels[f"{first_out_size}"]
|
|
||||||
self.conv_body_down = nn.ModuleList()
|
|
||||||
for i in range(self.log_size, 2, -1):
|
|
||||||
out_channels = channels[f"{2**(i - 1)}"]
|
|
||||||
self.conv_body_down.append(ResBlock(in_channels, out_channels))
|
|
||||||
in_channels = out_channels
|
|
||||||
|
|
||||||
self.final_conv = ConvLayer(
|
|
||||||
in_channels, channels["4"], 3, bias=True, activate=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# upsample
|
|
||||||
in_channels = channels["4"]
|
|
||||||
self.conv_body_up = nn.ModuleList()
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
out_channels = channels[f"{2**i}"]
|
|
||||||
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
|
|
||||||
in_channels = out_channels
|
|
||||||
|
|
||||||
# to RGB
|
|
||||||
self.toRGB = nn.ModuleList()
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
self.toRGB.append(
|
|
||||||
EqualConv2d(
|
|
||||||
channels[f"{2**i}"],
|
|
||||||
3,
|
|
||||||
1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if different_w:
|
|
||||||
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
|
||||||
else:
|
|
||||||
linear_out_channel = num_style_feat
|
|
||||||
|
|
||||||
self.final_linear = EqualLinear(
|
|
||||||
channels["4"] * 4 * 4,
|
|
||||||
linear_out_channel,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
lr_mul=1,
|
|
||||||
activation=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# the decoder: stylegan2 generator with SFT modulations
|
|
||||||
self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
|
|
||||||
out_size=out_size,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
num_mlp=num_mlp,
|
|
||||||
channel_multiplier=channel_multiplier,
|
|
||||||
lr_mlp=lr_mlp,
|
|
||||||
narrow=narrow,
|
|
||||||
sft_half=sft_half,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load pre-trained stylegan2 model if necessary
|
|
||||||
if decoder_load_path:
|
|
||||||
self.stylegan_decoder.load_state_dict(
|
|
||||||
torch.load(
|
|
||||||
decoder_load_path, map_location=lambda storage, loc: storage
|
|
||||||
)["params_ema"]
|
|
||||||
)
|
|
||||||
# fix decoder without updating params
|
|
||||||
if fix_decoder:
|
|
||||||
for _, param in self.stylegan_decoder.named_parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
# for SFT modulations (scale and shift)
|
|
||||||
self.condition_scale = nn.ModuleList()
|
|
||||||
self.condition_shift = nn.ModuleList()
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
out_channels = channels[f"{2**i}"]
|
|
||||||
if sft_half:
|
|
||||||
sft_out_channels = out_channels
|
|
||||||
else:
|
|
||||||
sft_out_channels = out_channels * 2
|
|
||||||
self.condition_scale.append(
|
|
||||||
nn.Sequential(
|
|
||||||
EqualConv2d(
|
|
||||||
out_channels,
|
|
||||||
out_channels,
|
|
||||||
3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
),
|
|
||||||
ScaledLeakyReLU(0.2),
|
|
||||||
EqualConv2d(
|
|
||||||
out_channels,
|
|
||||||
sft_out_channels,
|
|
||||||
3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.condition_shift.append(
|
|
||||||
nn.Sequential(
|
|
||||||
EqualConv2d(
|
|
||||||
out_channels,
|
|
||||||
out_channels,
|
|
||||||
3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
),
|
|
||||||
ScaledLeakyReLU(0.2),
|
|
||||||
EqualConv2d(
|
|
||||||
out_channels,
|
|
||||||
sft_out_channels,
|
|
||||||
3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
|
|
||||||
"""Forward function for GFPGANBilinear.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input images.
|
|
||||||
return_latents (bool): Whether to return style latents. Default: False.
|
|
||||||
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
|
||||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
|
||||||
"""
|
|
||||||
conditions = []
|
|
||||||
unet_skips = []
|
|
||||||
out_rgbs = []
|
|
||||||
|
|
||||||
# encoder
|
|
||||||
feat = self.conv_body_first(x)
|
|
||||||
for i in range(self.log_size - 2):
|
|
||||||
feat = self.conv_body_down[i](feat)
|
|
||||||
unet_skips.insert(0, feat)
|
|
||||||
|
|
||||||
feat = self.final_conv(feat)
|
|
||||||
|
|
||||||
# style code
|
|
||||||
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
|
||||||
if self.different_w:
|
|
||||||
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
|
||||||
|
|
||||||
# decode
|
|
||||||
for i in range(self.log_size - 2):
|
|
||||||
# add unet skip
|
|
||||||
feat = feat + unet_skips[i]
|
|
||||||
# ResUpLayer
|
|
||||||
feat = self.conv_body_up[i](feat)
|
|
||||||
# generate scale and shift for SFT layers
|
|
||||||
scale = self.condition_scale[i](feat)
|
|
||||||
conditions.append(scale.clone())
|
|
||||||
shift = self.condition_shift[i](feat)
|
|
||||||
conditions.append(shift.clone())
|
|
||||||
# generate rgb images
|
|
||||||
if return_rgb:
|
|
||||||
out_rgbs.append(self.toRGB[i](feat))
|
|
||||||
|
|
||||||
# decoder
|
|
||||||
image, _ = self.stylegan_decoder(
|
|
||||||
[style_code],
|
|
||||||
conditions,
|
|
||||||
return_latents=return_latents,
|
|
||||||
input_is_latent=self.input_is_latent,
|
|
||||||
randomize_noise=randomize_noise,
|
|
||||||
)
|
|
||||||
|
|
||||||
return image, out_rgbs
|
|
||||||
@ -1,566 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# type: ignore
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from .fused_act import FusedLeakyReLU
|
|
||||||
from .stylegan2_arch import (
|
|
||||||
ConvLayer,
|
|
||||||
EqualConv2d,
|
|
||||||
EqualLinear,
|
|
||||||
ResBlock,
|
|
||||||
ScaledLeakyReLU,
|
|
||||||
StyleGAN2Generator,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
|
||||||
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
|
||||||
Args:
|
|
||||||
out_size (int): The spatial size of outputs.
|
|
||||||
num_style_feat (int): Channel number of style features. Default: 512.
|
|
||||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
|
||||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
|
|
||||||
applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
|
|
||||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
|
||||||
narrow (float): The narrow ratio for channels. Default: 1.
|
|
||||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
out_size,
|
|
||||||
num_style_feat=512,
|
|
||||||
num_mlp=8,
|
|
||||||
channel_multiplier=2,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
lr_mlp=0.01,
|
|
||||||
narrow=1,
|
|
||||||
sft_half=False,
|
|
||||||
):
|
|
||||||
super(StyleGAN2GeneratorSFT, self).__init__(
|
|
||||||
out_size,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
num_mlp=num_mlp,
|
|
||||||
channel_multiplier=channel_multiplier,
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
lr_mlp=lr_mlp,
|
|
||||||
narrow=narrow,
|
|
||||||
)
|
|
||||||
self.sft_half = sft_half
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
styles,
|
|
||||||
conditions,
|
|
||||||
input_is_latent=False,
|
|
||||||
noise=None,
|
|
||||||
randomize_noise=True,
|
|
||||||
truncation=1,
|
|
||||||
truncation_latent=None,
|
|
||||||
inject_index=None,
|
|
||||||
return_latents=False,
|
|
||||||
):
|
|
||||||
"""Forward function for StyleGAN2GeneratorSFT.
|
|
||||||
Args:
|
|
||||||
styles (list[Tensor]): Sample codes of styles.
|
|
||||||
conditions (list[Tensor]): SFT conditions to generators.
|
|
||||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
|
||||||
noise (Tensor | None): Input noise or None. Default: None.
|
|
||||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
|
||||||
truncation (float): The truncation ratio. Default: 1.
|
|
||||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
|
||||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
|
||||||
return_latents (bool): Whether to return style latents. Default: False.
|
|
||||||
"""
|
|
||||||
# style codes -> latents with Style MLP layer
|
|
||||||
if not input_is_latent:
|
|
||||||
styles = [self.style_mlp(s) for s in styles]
|
|
||||||
# noises
|
|
||||||
if noise is None:
|
|
||||||
if randomize_noise:
|
|
||||||
noise = [None] * self.num_layers # for each style conv layer
|
|
||||||
else: # use the stored noise
|
|
||||||
noise = [
|
|
||||||
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
|
|
||||||
]
|
|
||||||
# style truncation
|
|
||||||
if truncation < 1:
|
|
||||||
style_truncation = []
|
|
||||||
for style in styles:
|
|
||||||
style_truncation.append(
|
|
||||||
truncation_latent + truncation * (style - truncation_latent)
|
|
||||||
)
|
|
||||||
styles = style_truncation
|
|
||||||
# get style latents with injection
|
|
||||||
if len(styles) == 1:
|
|
||||||
inject_index = self.num_latent
|
|
||||||
|
|
||||||
if styles[0].ndim < 3:
|
|
||||||
# repeat latent code for all the layers
|
|
||||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
else: # used for encoder with different latent code for each layer
|
|
||||||
latent = styles[0]
|
|
||||||
elif len(styles) == 2: # mixing noises
|
|
||||||
if inject_index is None:
|
|
||||||
inject_index = random.randint(1, self.num_latent - 1)
|
|
||||||
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
latent2 = (
|
|
||||||
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
|
||||||
)
|
|
||||||
latent = torch.cat([latent1, latent2], 1)
|
|
||||||
|
|
||||||
# main generation
|
|
||||||
out = self.constant_input(latent.shape[0])
|
|
||||||
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
|
||||||
skip = self.to_rgb1(out, latent[:, 1])
|
|
||||||
|
|
||||||
i = 1
|
|
||||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
|
||||||
self.style_convs[::2],
|
|
||||||
self.style_convs[1::2],
|
|
||||||
noise[1::2],
|
|
||||||
noise[2::2],
|
|
||||||
self.to_rgbs,
|
|
||||||
):
|
|
||||||
out = conv1(out, latent[:, i], noise=noise1)
|
|
||||||
|
|
||||||
# the conditions may have fewer levels
|
|
||||||
if i < len(conditions):
|
|
||||||
# SFT part to combine the conditions
|
|
||||||
if self.sft_half: # only apply SFT to half of the channels
|
|
||||||
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
|
||||||
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
|
||||||
out = torch.cat([out_same, out_sft], dim=1)
|
|
||||||
else: # apply SFT to all the channels
|
|
||||||
out = out * conditions[i - 1] + conditions[i]
|
|
||||||
|
|
||||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
|
||||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
|
||||||
i += 2
|
|
||||||
|
|
||||||
image = skip
|
|
||||||
|
|
||||||
if return_latents:
|
|
||||||
return image, latent
|
|
||||||
else:
|
|
||||||
return image, None
|
|
||||||
|
|
||||||
|
|
||||||
class ConvUpLayer(nn.Module):
|
|
||||||
"""Convolutional upsampling layer. It uses bilinear upsampler + Conv.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Size of the convolving kernel.
|
|
||||||
stride (int): Stride of the convolution. Default: 1
|
|
||||||
padding (int): Zero-padding added to both sides of the input. Default: 0.
|
|
||||||
bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
|
|
||||||
bias_init_val (float): Bias initialized value. Default: 0.
|
|
||||||
activate (bool): Whether use activateion. Default: True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
activate=True,
|
|
||||||
):
|
|
||||||
super(ConvUpLayer, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.stride = stride
|
|
||||||
self.padding = padding
|
|
||||||
# self.scale is used to scale the convolution weights, which is related to the common initializations.
|
|
||||||
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(
|
|
||||||
torch.randn(out_channels, in_channels, kernel_size, kernel_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
if bias and not activate:
|
|
||||||
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
# activation
|
|
||||||
if activate:
|
|
||||||
if bias:
|
|
||||||
self.activation = FusedLeakyReLU(out_channels)
|
|
||||||
else:
|
|
||||||
self.activation = ScaledLeakyReLU(0.2)
|
|
||||||
else:
|
|
||||||
self.activation = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# bilinear upsample
|
|
||||||
out = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
|
|
||||||
# conv
|
|
||||||
out = F.conv2d(
|
|
||||||
out,
|
|
||||||
self.weight * self.scale,
|
|
||||||
bias=self.bias,
|
|
||||||
stride=self.stride,
|
|
||||||
padding=self.padding,
|
|
||||||
)
|
|
||||||
# activation
|
|
||||||
if self.activation is not None:
|
|
||||||
out = self.activation(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ResUpBlock(nn.Module):
|
|
||||||
"""Residual block with upsampling.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels):
|
|
||||||
super(ResUpBlock, self).__init__()
|
|
||||||
|
|
||||||
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
|
|
||||||
self.conv2 = ConvUpLayer(
|
|
||||||
in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True
|
|
||||||
)
|
|
||||||
self.skip = ConvUpLayer(
|
|
||||||
in_channels, out_channels, 1, bias=False, activate=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.conv2(out)
|
|
||||||
skip = self.skip(x)
|
|
||||||
out = (out + skip) / math.sqrt(2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class GFPGANv1(nn.Module):
|
|
||||||
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
|
||||||
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
|
||||||
Args:
|
|
||||||
out_size (int): The spatial size of outputs.
|
|
||||||
num_style_feat (int): Channel number of style features. Default: 512.
|
|
||||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
|
|
||||||
applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
|
|
||||||
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
|
||||||
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
|
||||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
|
||||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
|
||||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
|
||||||
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
|
||||||
narrow (float): The narrow ratio for channels. Default: 1.
|
|
||||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
out_size,
|
|
||||||
num_style_feat=512,
|
|
||||||
channel_multiplier=1,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
decoder_load_path=None,
|
|
||||||
fix_decoder=True,
|
|
||||||
# for stylegan decoder
|
|
||||||
num_mlp=8,
|
|
||||||
lr_mlp=0.01,
|
|
||||||
input_is_latent=False,
|
|
||||||
different_w=False,
|
|
||||||
narrow=1,
|
|
||||||
sft_half=False,
|
|
||||||
):
|
|
||||||
super(GFPGANv1, self).__init__()
|
|
||||||
self.input_is_latent = input_is_latent
|
|
||||||
self.different_w = different_w
|
|
||||||
self.num_style_feat = num_style_feat
|
|
||||||
|
|
||||||
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
|
||||||
channels = {
|
|
||||||
"4": int(512 * unet_narrow),
|
|
||||||
"8": int(512 * unet_narrow),
|
|
||||||
"16": int(512 * unet_narrow),
|
|
||||||
"32": int(512 * unet_narrow),
|
|
||||||
"64": int(256 * channel_multiplier * unet_narrow),
|
|
||||||
"128": int(128 * channel_multiplier * unet_narrow),
|
|
||||||
"256": int(64 * channel_multiplier * unet_narrow),
|
|
||||||
"512": int(32 * channel_multiplier * unet_narrow),
|
|
||||||
"1024": int(16 * channel_multiplier * unet_narrow),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.log_size = int(math.log(out_size, 2))
|
|
||||||
first_out_size = 2 ** (int(math.log(out_size, 2)))
|
|
||||||
|
|
||||||
self.conv_body_first = ConvLayer(
|
|
||||||
3, channels[f"{first_out_size}"], 1, bias=True, activate=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# downsample
|
|
||||||
in_channels = channels[f"{first_out_size}"]
|
|
||||||
self.conv_body_down = nn.ModuleList()
|
|
||||||
for i in range(self.log_size, 2, -1):
|
|
||||||
out_channels = channels[f"{2**(i - 1)}"]
|
|
||||||
self.conv_body_down.append(
|
|
||||||
ResBlock(in_channels, out_channels, resample_kernel)
|
|
||||||
)
|
|
||||||
in_channels = out_channels
|
|
||||||
|
|
||||||
self.final_conv = ConvLayer(
|
|
||||||
in_channels, channels["4"], 3, bias=True, activate=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# upsample
|
|
||||||
in_channels = channels["4"]
|
|
||||||
self.conv_body_up = nn.ModuleList()
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
out_channels = channels[f"{2**i}"]
|
|
||||||
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
|
|
||||||
in_channels = out_channels
|
|
||||||
|
|
||||||
# to RGB
|
|
||||||
self.toRGB = nn.ModuleList()
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
self.toRGB.append(
|
|
||||||
EqualConv2d(
|
|
||||||
channels[f"{2**i}"],
|
|
||||||
3,
|
|
||||||
1,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if different_w:
|
|
||||||
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
|
||||||
else:
|
|
||||||
linear_out_channel = num_style_feat
|
|
||||||
|
|
||||||
self.final_linear = EqualLinear(
|
|
||||||
channels["4"] * 4 * 4,
|
|
||||||
linear_out_channel,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
lr_mul=1,
|
|
||||||
activation=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# the decoder: stylegan2 generator with SFT modulations
|
|
||||||
self.stylegan_decoder = StyleGAN2GeneratorSFT(
|
|
||||||
out_size=out_size,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
num_mlp=num_mlp,
|
|
||||||
channel_multiplier=channel_multiplier,
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
lr_mlp=lr_mlp,
|
|
||||||
narrow=narrow,
|
|
||||||
sft_half=sft_half,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load pre-trained stylegan2 model if necessary
|
|
||||||
if decoder_load_path:
|
|
||||||
self.stylegan_decoder.load_state_dict(
|
|
||||||
torch.load(
|
|
||||||
decoder_load_path, map_location=lambda storage, loc: storage
|
|
||||||
)["params_ema"]
|
|
||||||
)
|
|
||||||
# fix decoder without updating params
|
|
||||||
if fix_decoder:
|
|
||||||
for _, param in self.stylegan_decoder.named_parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
# for SFT modulations (scale and shift)
|
|
||||||
self.condition_scale = nn.ModuleList()
|
|
||||||
self.condition_shift = nn.ModuleList()
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
out_channels = channels[f"{2**i}"]
|
|
||||||
if sft_half:
|
|
||||||
sft_out_channels = out_channels
|
|
||||||
else:
|
|
||||||
sft_out_channels = out_channels * 2
|
|
||||||
self.condition_scale.append(
|
|
||||||
nn.Sequential(
|
|
||||||
EqualConv2d(
|
|
||||||
out_channels,
|
|
||||||
out_channels,
|
|
||||||
3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
),
|
|
||||||
ScaledLeakyReLU(0.2),
|
|
||||||
EqualConv2d(
|
|
||||||
out_channels,
|
|
||||||
sft_out_channels,
|
|
||||||
3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.condition_shift.append(
|
|
||||||
nn.Sequential(
|
|
||||||
EqualConv2d(
|
|
||||||
out_channels,
|
|
||||||
out_channels,
|
|
||||||
3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
),
|
|
||||||
ScaledLeakyReLU(0.2),
|
|
||||||
EqualConv2d(
|
|
||||||
out_channels,
|
|
||||||
sft_out_channels,
|
|
||||||
3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
|
|
||||||
):
|
|
||||||
"""Forward function for GFPGANv1.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input images.
|
|
||||||
return_latents (bool): Whether to return style latents. Default: False.
|
|
||||||
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
|
||||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
|
||||||
"""
|
|
||||||
conditions = []
|
|
||||||
unet_skips = []
|
|
||||||
out_rgbs = []
|
|
||||||
|
|
||||||
# encoder
|
|
||||||
feat = self.conv_body_first(x)
|
|
||||||
for i in range(self.log_size - 2):
|
|
||||||
feat = self.conv_body_down[i](feat)
|
|
||||||
unet_skips.insert(0, feat)
|
|
||||||
|
|
||||||
feat = self.final_conv(feat)
|
|
||||||
|
|
||||||
# style code
|
|
||||||
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
|
||||||
if self.different_w:
|
|
||||||
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
|
||||||
|
|
||||||
# decode
|
|
||||||
for i in range(self.log_size - 2):
|
|
||||||
# add unet skip
|
|
||||||
feat = feat + unet_skips[i]
|
|
||||||
# ResUpLayer
|
|
||||||
feat = self.conv_body_up[i](feat)
|
|
||||||
# generate scale and shift for SFT layers
|
|
||||||
scale = self.condition_scale[i](feat)
|
|
||||||
conditions.append(scale.clone())
|
|
||||||
shift = self.condition_shift[i](feat)
|
|
||||||
conditions.append(shift.clone())
|
|
||||||
# generate rgb images
|
|
||||||
if return_rgb:
|
|
||||||
out_rgbs.append(self.toRGB[i](feat))
|
|
||||||
|
|
||||||
# decoder
|
|
||||||
image, _ = self.stylegan_decoder(
|
|
||||||
[style_code],
|
|
||||||
conditions,
|
|
||||||
return_latents=return_latents,
|
|
||||||
input_is_latent=self.input_is_latent,
|
|
||||||
randomize_noise=randomize_noise,
|
|
||||||
)
|
|
||||||
|
|
||||||
return image, out_rgbs
|
|
||||||
|
|
||||||
|
|
||||||
class FacialComponentDiscriminator(nn.Module):
|
|
||||||
"""Facial component (eyes, mouth, noise) discriminator used in GFPGAN."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(FacialComponentDiscriminator, self).__init__()
|
|
||||||
# It now uses a VGG-style architectrue with fixed model size
|
|
||||||
self.conv1 = ConvLayer(
|
|
||||||
3,
|
|
||||||
64,
|
|
||||||
3,
|
|
||||||
downsample=False,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
bias=True,
|
|
||||||
activate=True,
|
|
||||||
)
|
|
||||||
self.conv2 = ConvLayer(
|
|
||||||
64,
|
|
||||||
128,
|
|
||||||
3,
|
|
||||||
downsample=True,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
bias=True,
|
|
||||||
activate=True,
|
|
||||||
)
|
|
||||||
self.conv3 = ConvLayer(
|
|
||||||
128,
|
|
||||||
128,
|
|
||||||
3,
|
|
||||||
downsample=False,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
bias=True,
|
|
||||||
activate=True,
|
|
||||||
)
|
|
||||||
self.conv4 = ConvLayer(
|
|
||||||
128,
|
|
||||||
256,
|
|
||||||
3,
|
|
||||||
downsample=True,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
bias=True,
|
|
||||||
activate=True,
|
|
||||||
)
|
|
||||||
self.conv5 = ConvLayer(
|
|
||||||
256,
|
|
||||||
256,
|
|
||||||
3,
|
|
||||||
downsample=False,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
bias=True,
|
|
||||||
activate=True,
|
|
||||||
)
|
|
||||||
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
|
|
||||||
|
|
||||||
def forward(self, x, return_feats=False, **kwargs):
|
|
||||||
"""Forward function for FacialComponentDiscriminator.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input images.
|
|
||||||
return_feats (bool): Whether to return intermediate features. Default: False.
|
|
||||||
"""
|
|
||||||
feat = self.conv1(x)
|
|
||||||
feat = self.conv3(self.conv2(feat))
|
|
||||||
rlt_feats = []
|
|
||||||
if return_feats:
|
|
||||||
rlt_feats.append(feat.clone())
|
|
||||||
feat = self.conv5(self.conv4(feat))
|
|
||||||
if return_feats:
|
|
||||||
rlt_feats.append(feat.clone())
|
|
||||||
out = self.final_conv(feat)
|
|
||||||
|
|
||||||
if return_feats:
|
|
||||||
return out, rlt_feats
|
|
||||||
else:
|
|
||||||
return out, None
|
|
||||||
@ -1,370 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# type: ignore
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from .stylegan2_clean_arch import StyleGAN2GeneratorClean
|
|
||||||
|
|
||||||
|
|
||||||
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
|
||||||
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
|
||||||
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
|
||||||
Args:
|
|
||||||
out_size (int): The spatial size of outputs.
|
|
||||||
num_style_feat (int): Channel number of style features. Default: 512.
|
|
||||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
|
||||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
|
||||||
narrow (float): The narrow ratio for channels. Default: 1.
|
|
||||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
out_size,
|
|
||||||
num_style_feat=512,
|
|
||||||
num_mlp=8,
|
|
||||||
channel_multiplier=2,
|
|
||||||
narrow=1,
|
|
||||||
sft_half=False,
|
|
||||||
):
|
|
||||||
super(StyleGAN2GeneratorCSFT, self).__init__(
|
|
||||||
out_size,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
num_mlp=num_mlp,
|
|
||||||
channel_multiplier=channel_multiplier,
|
|
||||||
narrow=narrow,
|
|
||||||
)
|
|
||||||
self.sft_half = sft_half
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
styles,
|
|
||||||
conditions,
|
|
||||||
input_is_latent=False,
|
|
||||||
noise=None,
|
|
||||||
randomize_noise=True,
|
|
||||||
truncation=1,
|
|
||||||
truncation_latent=None,
|
|
||||||
inject_index=None,
|
|
||||||
return_latents=False,
|
|
||||||
):
|
|
||||||
"""Forward function for StyleGAN2GeneratorCSFT.
|
|
||||||
Args:
|
|
||||||
styles (list[Tensor]): Sample codes of styles.
|
|
||||||
conditions (list[Tensor]): SFT conditions to generators.
|
|
||||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
|
||||||
noise (Tensor | None): Input noise or None. Default: None.
|
|
||||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
|
||||||
truncation (float): The truncation ratio. Default: 1.
|
|
||||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
|
||||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
|
||||||
return_latents (bool): Whether to return style latents. Default: False.
|
|
||||||
"""
|
|
||||||
# style codes -> latents with Style MLP layer
|
|
||||||
if not input_is_latent:
|
|
||||||
styles = [self.style_mlp(s) for s in styles]
|
|
||||||
# noises
|
|
||||||
if noise is None:
|
|
||||||
if randomize_noise:
|
|
||||||
noise = [None] * self.num_layers # for each style conv layer
|
|
||||||
else: # use the stored noise
|
|
||||||
noise = [
|
|
||||||
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
|
|
||||||
]
|
|
||||||
# style truncation
|
|
||||||
if truncation < 1:
|
|
||||||
style_truncation = []
|
|
||||||
for style in styles:
|
|
||||||
style_truncation.append(
|
|
||||||
truncation_latent + truncation * (style - truncation_latent)
|
|
||||||
)
|
|
||||||
styles = style_truncation
|
|
||||||
# get style latents with injection
|
|
||||||
if len(styles) == 1:
|
|
||||||
inject_index = self.num_latent
|
|
||||||
|
|
||||||
if styles[0].ndim < 3:
|
|
||||||
# repeat latent code for all the layers
|
|
||||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
else: # used for encoder with different latent code for each layer
|
|
||||||
latent = styles[0]
|
|
||||||
elif len(styles) == 2: # mixing noises
|
|
||||||
if inject_index is None:
|
|
||||||
inject_index = random.randint(1, self.num_latent - 1)
|
|
||||||
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
latent2 = (
|
|
||||||
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
|
||||||
)
|
|
||||||
latent = torch.cat([latent1, latent2], 1)
|
|
||||||
|
|
||||||
# main generation
|
|
||||||
out = self.constant_input(latent.shape[0])
|
|
||||||
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
|
||||||
skip = self.to_rgb1(out, latent[:, 1])
|
|
||||||
|
|
||||||
i = 1
|
|
||||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
|
||||||
self.style_convs[::2],
|
|
||||||
self.style_convs[1::2],
|
|
||||||
noise[1::2],
|
|
||||||
noise[2::2],
|
|
||||||
self.to_rgbs,
|
|
||||||
):
|
|
||||||
out = conv1(out, latent[:, i], noise=noise1)
|
|
||||||
|
|
||||||
# the conditions may have fewer levels
|
|
||||||
if i < len(conditions):
|
|
||||||
# SFT part to combine the conditions
|
|
||||||
if self.sft_half: # only apply SFT to half of the channels
|
|
||||||
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
|
||||||
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
|
||||||
out = torch.cat([out_same, out_sft], dim=1)
|
|
||||||
else: # apply SFT to all the channels
|
|
||||||
out = out * conditions[i - 1] + conditions[i]
|
|
||||||
|
|
||||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
|
||||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
|
||||||
i += 2
|
|
||||||
|
|
||||||
image = skip
|
|
||||||
|
|
||||||
if return_latents:
|
|
||||||
return image, latent
|
|
||||||
else:
|
|
||||||
return image, None
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
|
||||||
"""Residual block with bilinear upsampling/downsampling.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, mode="down"):
|
|
||||||
super(ResBlock, self).__init__()
|
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
|
|
||||||
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
|
|
||||||
self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
|
|
||||||
if mode == "down":
|
|
||||||
self.scale_factor = 0.5
|
|
||||||
elif mode == "up":
|
|
||||||
self.scale_factor = 2
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
|
|
||||||
# upsample/downsample
|
|
||||||
out = F.interpolate(
|
|
||||||
out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
|
|
||||||
# skip
|
|
||||||
x = F.interpolate(
|
|
||||||
x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
skip = self.skip(x)
|
|
||||||
out = out + skip
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class GFPGANv1Clean(nn.Module):
|
|
||||||
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
|
||||||
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
|
||||||
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
|
||||||
Args:
|
|
||||||
out_size (int): The spatial size of outputs.
|
|
||||||
num_style_feat (int): Channel number of style features. Default: 512.
|
|
||||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
|
||||||
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
|
||||||
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
|
||||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
|
||||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
|
||||||
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
|
||||||
narrow (float): The narrow ratio for channels. Default: 1.
|
|
||||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dict,
|
|
||||||
):
|
|
||||||
super(GFPGANv1Clean, self).__init__()
|
|
||||||
|
|
||||||
out_size = 512
|
|
||||||
num_style_feat = 512
|
|
||||||
channel_multiplier = 2
|
|
||||||
decoder_load_path = None
|
|
||||||
fix_decoder = False
|
|
||||||
num_mlp = 8
|
|
||||||
input_is_latent = True
|
|
||||||
different_w = True
|
|
||||||
narrow = 1
|
|
||||||
sft_half = True
|
|
||||||
|
|
||||||
self.model_arch = "GFPGAN"
|
|
||||||
self.sub_type = "Face SR"
|
|
||||||
self.scale = 8
|
|
||||||
self.in_nc = 3
|
|
||||||
self.out_nc = 3
|
|
||||||
self.state = state_dict
|
|
||||||
|
|
||||||
self.supports_fp16 = False
|
|
||||||
self.supports_bf16 = True
|
|
||||||
self.min_size_restriction = 512
|
|
||||||
|
|
||||||
self.input_is_latent = input_is_latent
|
|
||||||
self.different_w = different_w
|
|
||||||
self.num_style_feat = num_style_feat
|
|
||||||
|
|
||||||
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
|
||||||
channels = {
|
|
||||||
"4": int(512 * unet_narrow),
|
|
||||||
"8": int(512 * unet_narrow),
|
|
||||||
"16": int(512 * unet_narrow),
|
|
||||||
"32": int(512 * unet_narrow),
|
|
||||||
"64": int(256 * channel_multiplier * unet_narrow),
|
|
||||||
"128": int(128 * channel_multiplier * unet_narrow),
|
|
||||||
"256": int(64 * channel_multiplier * unet_narrow),
|
|
||||||
"512": int(32 * channel_multiplier * unet_narrow),
|
|
||||||
"1024": int(16 * channel_multiplier * unet_narrow),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.log_size = int(math.log(out_size, 2))
|
|
||||||
first_out_size = 2 ** (int(math.log(out_size, 2)))
|
|
||||||
|
|
||||||
self.conv_body_first = nn.Conv2d(3, channels[f"{first_out_size}"], 1)
|
|
||||||
|
|
||||||
# downsample
|
|
||||||
in_channels = channels[f"{first_out_size}"]
|
|
||||||
self.conv_body_down = nn.ModuleList()
|
|
||||||
for i in range(self.log_size, 2, -1):
|
|
||||||
out_channels = channels[f"{2**(i - 1)}"]
|
|
||||||
self.conv_body_down.append(ResBlock(in_channels, out_channels, mode="down"))
|
|
||||||
in_channels = out_channels
|
|
||||||
|
|
||||||
self.final_conv = nn.Conv2d(in_channels, channels["4"], 3, 1, 1)
|
|
||||||
|
|
||||||
# upsample
|
|
||||||
in_channels = channels["4"]
|
|
||||||
self.conv_body_up = nn.ModuleList()
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
out_channels = channels[f"{2**i}"]
|
|
||||||
self.conv_body_up.append(ResBlock(in_channels, out_channels, mode="up"))
|
|
||||||
in_channels = out_channels
|
|
||||||
|
|
||||||
# to RGB
|
|
||||||
self.toRGB = nn.ModuleList()
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
self.toRGB.append(nn.Conv2d(channels[f"{2**i}"], 3, 1))
|
|
||||||
|
|
||||||
if different_w:
|
|
||||||
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
|
||||||
else:
|
|
||||||
linear_out_channel = num_style_feat
|
|
||||||
|
|
||||||
self.final_linear = nn.Linear(channels["4"] * 4 * 4, linear_out_channel)
|
|
||||||
|
|
||||||
# the decoder: stylegan2 generator with SFT modulations
|
|
||||||
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
|
|
||||||
out_size=out_size,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
num_mlp=num_mlp,
|
|
||||||
channel_multiplier=channel_multiplier,
|
|
||||||
narrow=narrow,
|
|
||||||
sft_half=sft_half,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load pre-trained stylegan2 model if necessary
|
|
||||||
if decoder_load_path:
|
|
||||||
self.stylegan_decoder.load_state_dict(
|
|
||||||
torch.load(
|
|
||||||
decoder_load_path, map_location=lambda storage, loc: storage
|
|
||||||
)["params_ema"]
|
|
||||||
)
|
|
||||||
# fix decoder without updating params
|
|
||||||
if fix_decoder:
|
|
||||||
for _, param in self.stylegan_decoder.named_parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
# for SFT modulations (scale and shift)
|
|
||||||
self.condition_scale = nn.ModuleList()
|
|
||||||
self.condition_shift = nn.ModuleList()
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
out_channels = channels[f"{2**i}"]
|
|
||||||
if sft_half:
|
|
||||||
sft_out_channels = out_channels
|
|
||||||
else:
|
|
||||||
sft_out_channels = out_channels * 2
|
|
||||||
self.condition_scale.append(
|
|
||||||
nn.Sequential(
|
|
||||||
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.condition_shift.append(
|
|
||||||
nn.Sequential(
|
|
||||||
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
|
|
||||||
):
|
|
||||||
"""Forward function for GFPGANv1Clean.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Input images.
|
|
||||||
return_latents (bool): Whether to return style latents. Default: False.
|
|
||||||
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
|
||||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
|
||||||
"""
|
|
||||||
conditions = []
|
|
||||||
unet_skips = []
|
|
||||||
out_rgbs = []
|
|
||||||
|
|
||||||
# encoder
|
|
||||||
feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
|
|
||||||
for i in range(self.log_size - 2):
|
|
||||||
feat = self.conv_body_down[i](feat)
|
|
||||||
unet_skips.insert(0, feat)
|
|
||||||
feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
|
|
||||||
|
|
||||||
# style code
|
|
||||||
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
|
||||||
if self.different_w:
|
|
||||||
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
|
||||||
|
|
||||||
# decode
|
|
||||||
for i in range(self.log_size - 2):
|
|
||||||
# add unet skip
|
|
||||||
feat = feat + unet_skips[i]
|
|
||||||
# ResUpLayer
|
|
||||||
feat = self.conv_body_up[i](feat)
|
|
||||||
# generate scale and shift for SFT layers
|
|
||||||
scale = self.condition_scale[i](feat)
|
|
||||||
conditions.append(scale.clone())
|
|
||||||
shift = self.condition_shift[i](feat)
|
|
||||||
conditions.append(shift.clone())
|
|
||||||
# generate rgb images
|
|
||||||
if return_rgb:
|
|
||||||
out_rgbs.append(self.toRGB[i](feat))
|
|
||||||
|
|
||||||
# decoder
|
|
||||||
image, _ = self.stylegan_decoder(
|
|
||||||
[style_code],
|
|
||||||
conditions,
|
|
||||||
return_latents=return_latents,
|
|
||||||
input_is_latent=self.input_is_latent,
|
|
||||||
randomize_noise=randomize_noise,
|
|
||||||
)
|
|
||||||
|
|
||||||
return image, out_rgbs
|
|
||||||
@ -1,776 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# type: ignore
|
|
||||||
"""Modified from https://github.com/wzhouxiff/RestoreFormer
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class VectorQuantizer(nn.Module):
|
|
||||||
"""
|
|
||||||
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
|
||||||
____________________________________________
|
|
||||||
Discretization bottleneck part of the VQ-VAE.
|
|
||||||
Inputs:
|
|
||||||
- n_e : number of embeddings
|
|
||||||
- e_dim : dimension of embedding
|
|
||||||
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
|
||||||
_____________________________________________
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, n_e, e_dim, beta):
|
|
||||||
super(VectorQuantizer, self).__init__()
|
|
||||||
self.n_e = n_e
|
|
||||||
self.e_dim = e_dim
|
|
||||||
self.beta = beta
|
|
||||||
|
|
||||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
|
||||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
"""
|
|
||||||
Inputs the output of the encoder network z and maps it to a discrete
|
|
||||||
one-hot vector that is the index of the closest embedding vector e_j
|
|
||||||
z (continuous) -> z_q (discrete)
|
|
||||||
z.shape = (batch, channel, height, width)
|
|
||||||
quantization pipeline:
|
|
||||||
1. get encoder input (B,C,H,W)
|
|
||||||
2. flatten input to (B*H*W,C)
|
|
||||||
"""
|
|
||||||
# reshape z -> (batch, height, width, channel) and flatten
|
|
||||||
z = z.permute(0, 2, 3, 1).contiguous()
|
|
||||||
z_flattened = z.view(-1, self.e_dim)
|
|
||||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
|
||||||
|
|
||||||
d = (
|
|
||||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
|
||||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
|
||||||
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
|
||||||
)
|
|
||||||
|
|
||||||
# could possible replace this here
|
|
||||||
# #\start...
|
|
||||||
# find closest encodings
|
|
||||||
|
|
||||||
min_value, min_encoding_indices = torch.min(d, dim=1)
|
|
||||||
|
|
||||||
min_encoding_indices = min_encoding_indices.unsqueeze(1)
|
|
||||||
|
|
||||||
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
|
|
||||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
|
||||||
|
|
||||||
# dtype min encodings: torch.float32
|
|
||||||
# min_encodings shape: torch.Size([2048, 512])
|
|
||||||
# min_encoding_indices.shape: torch.Size([2048, 1])
|
|
||||||
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
|
||||||
# .........\end
|
|
||||||
|
|
||||||
# with:
|
|
||||||
# .........\start
|
|
||||||
# min_encoding_indices = torch.argmin(d, dim=1)
|
|
||||||
# z_q = self.embedding(min_encoding_indices)
|
|
||||||
# ......\end......... (TODO)
|
|
||||||
|
|
||||||
# compute loss for embedding
|
|
||||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
|
||||||
(z_q - z.detach()) ** 2
|
|
||||||
)
|
|
||||||
|
|
||||||
# preserve gradients
|
|
||||||
z_q = z + (z_q - z).detach()
|
|
||||||
|
|
||||||
# perplexity
|
|
||||||
|
|
||||||
e_mean = torch.mean(min_encodings, dim=0)
|
|
||||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
|
||||||
|
|
||||||
# reshape back to match original input shape
|
|
||||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
|
|
||||||
|
|
||||||
def get_codebook_entry(self, indices, shape):
|
|
||||||
# shape specifying (batch, height, width, channel)
|
|
||||||
# TODO: check for more easy handling with nn.Embedding
|
|
||||||
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
|
|
||||||
min_encodings.scatter_(1, indices[:, None], 1)
|
|
||||||
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
|
||||||
|
|
||||||
if shape is not None:
|
|
||||||
z_q = z_q.view(shape)
|
|
||||||
|
|
||||||
# reshape back to match original input shape
|
|
||||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return z_q
|
|
||||||
|
|
||||||
|
|
||||||
# pytorch_diffusion + derived encoder decoder
|
|
||||||
def nonlinearity(x):
|
|
||||||
# swish
|
|
||||||
return x * torch.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
def Normalize(in_channels):
|
|
||||||
return torch.nn.GroupNorm(
|
|
||||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
|
||||||
def __init__(self, in_channels, with_conv):
|
|
||||||
super().__init__()
|
|
||||||
self.with_conv = with_conv
|
|
||||||
if self.with_conv:
|
|
||||||
self.conv = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
|
||||||
if self.with_conv:
|
|
||||||
x = self.conv(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
|
||||||
def __init__(self, in_channels, with_conv):
|
|
||||||
super().__init__()
|
|
||||||
self.with_conv = with_conv
|
|
||||||
if self.with_conv:
|
|
||||||
# no asymmetric padding in torch conv, must do it ourselves
|
|
||||||
self.conv = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.with_conv:
|
|
||||||
pad = (0, 1, 0, 1)
|
|
||||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
|
||||||
x = self.conv(x)
|
|
||||||
else:
|
|
||||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
in_channels,
|
|
||||||
out_channels=None,
|
|
||||||
conv_shortcut=False,
|
|
||||||
dropout,
|
|
||||||
temb_channels=512
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
out_channels = in_channels if out_channels is None else out_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.use_conv_shortcut = conv_shortcut
|
|
||||||
|
|
||||||
self.norm1 = Normalize(in_channels)
|
|
||||||
self.conv1 = torch.nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
if temb_channels > 0:
|
|
||||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
|
||||||
self.norm2 = Normalize(out_channels)
|
|
||||||
self.dropout = torch.nn.Dropout(dropout)
|
|
||||||
self.conv2 = torch.nn.Conv2d(
|
|
||||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
if self.use_conv_shortcut:
|
|
||||||
self.conv_shortcut = torch.nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.nin_shortcut = torch.nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, temb):
|
|
||||||
h = x
|
|
||||||
h = self.norm1(h)
|
|
||||||
h = nonlinearity(h)
|
|
||||||
h = self.conv1(h)
|
|
||||||
|
|
||||||
if temb is not None:
|
|
||||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
|
||||||
|
|
||||||
h = self.norm2(h)
|
|
||||||
h = nonlinearity(h)
|
|
||||||
h = self.dropout(h)
|
|
||||||
h = self.conv2(h)
|
|
||||||
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
if self.use_conv_shortcut:
|
|
||||||
x = self.conv_shortcut(x)
|
|
||||||
else:
|
|
||||||
x = self.nin_shortcut(x)
|
|
||||||
|
|
||||||
return x + h
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttnBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, head_size=1):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.head_size = head_size
|
|
||||||
self.att_size = in_channels // head_size
|
|
||||||
assert (
|
|
||||||
in_channels % head_size == 0
|
|
||||||
), "The size of head should be divided by the number of channels."
|
|
||||||
|
|
||||||
self.norm1 = Normalize(in_channels)
|
|
||||||
self.norm2 = Normalize(in_channels)
|
|
||||||
|
|
||||||
self.q = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.k = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.v = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.proj_out = torch.nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
|
||||||
)
|
|
||||||
self.num = 0
|
|
||||||
|
|
||||||
def forward(self, x, y=None):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm1(h_)
|
|
||||||
if y is None:
|
|
||||||
y = h_
|
|
||||||
else:
|
|
||||||
y = self.norm2(y)
|
|
||||||
|
|
||||||
q = self.q(y)
|
|
||||||
k = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b, c, h, w = q.shape
|
|
||||||
q = q.reshape(b, self.head_size, self.att_size, h * w)
|
|
||||||
q = q.permute(0, 3, 1, 2) # b, hw, head, att
|
|
||||||
|
|
||||||
k = k.reshape(b, self.head_size, self.att_size, h * w)
|
|
||||||
k = k.permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
v = v.reshape(b, self.head_size, self.att_size, h * w)
|
|
||||||
v = v.permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
q = q.transpose(1, 2)
|
|
||||||
v = v.transpose(1, 2)
|
|
||||||
k = k.transpose(1, 2).transpose(2, 3)
|
|
||||||
|
|
||||||
scale = int(self.att_size) ** (-0.5)
|
|
||||||
q.mul_(scale)
|
|
||||||
w_ = torch.matmul(q, k)
|
|
||||||
w_ = F.softmax(w_, dim=3)
|
|
||||||
|
|
||||||
w_ = w_.matmul(v)
|
|
||||||
|
|
||||||
w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
|
|
||||||
w_ = w_.view(b, h, w, -1)
|
|
||||||
w_ = w_.permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
w_ = self.proj_out(w_)
|
|
||||||
|
|
||||||
return x + w_
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadEncoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
ch,
|
|
||||||
out_ch,
|
|
||||||
ch_mult=(1, 2, 4, 8),
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_resolutions=(16,),
|
|
||||||
dropout=0.0,
|
|
||||||
resamp_with_conv=True,
|
|
||||||
in_channels=3,
|
|
||||||
resolution=512,
|
|
||||||
z_channels=256,
|
|
||||||
double_z=True,
|
|
||||||
enable_mid=True,
|
|
||||||
head_size=1,
|
|
||||||
**ignore_kwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.ch = ch
|
|
||||||
self.temb_ch = 0
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.enable_mid = enable_mid
|
|
||||||
|
|
||||||
# downsampling
|
|
||||||
self.conv_in = torch.nn.Conv2d(
|
|
||||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
curr_res = resolution
|
|
||||||
in_ch_mult = (1,) + tuple(ch_mult)
|
|
||||||
self.down = nn.ModuleList()
|
|
||||||
for i_level in range(self.num_resolutions):
|
|
||||||
block = nn.ModuleList()
|
|
||||||
attn = nn.ModuleList()
|
|
||||||
block_in = ch * in_ch_mult[i_level]
|
|
||||||
block_out = ch * ch_mult[i_level]
|
|
||||||
for i_block in range(self.num_res_blocks):
|
|
||||||
block.append(
|
|
||||||
ResnetBlock(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_out,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
block_in = block_out
|
|
||||||
if curr_res in attn_resolutions:
|
|
||||||
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
|
||||||
down = nn.Module()
|
|
||||||
down.block = block
|
|
||||||
down.attn = attn
|
|
||||||
if i_level != self.num_resolutions - 1:
|
|
||||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
|
||||||
curr_res = curr_res // 2
|
|
||||||
self.down.append(down)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
if self.enable_mid:
|
|
||||||
self.mid = nn.Module()
|
|
||||||
self.mid.block_1 = ResnetBlock(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_in,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
|
||||||
self.mid.block_2 = ResnetBlock(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_in,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
|
|
||||||
# end
|
|
||||||
self.norm_out = Normalize(block_in)
|
|
||||||
self.conv_out = torch.nn.Conv2d(
|
|
||||||
block_in,
|
|
||||||
2 * z_channels if double_z else z_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
hs = {}
|
|
||||||
# timestep embedding
|
|
||||||
temb = None
|
|
||||||
|
|
||||||
# downsampling
|
|
||||||
h = self.conv_in(x)
|
|
||||||
hs["in"] = h
|
|
||||||
for i_level in range(self.num_resolutions):
|
|
||||||
for i_block in range(self.num_res_blocks):
|
|
||||||
h = self.down[i_level].block[i_block](h, temb)
|
|
||||||
if len(self.down[i_level].attn) > 0:
|
|
||||||
h = self.down[i_level].attn[i_block](h)
|
|
||||||
|
|
||||||
if i_level != self.num_resolutions - 1:
|
|
||||||
# hs.append(h)
|
|
||||||
hs["block_" + str(i_level)] = h
|
|
||||||
h = self.down[i_level].downsample(h)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
# h = hs[-1]
|
|
||||||
if self.enable_mid:
|
|
||||||
h = self.mid.block_1(h, temb)
|
|
||||||
hs["block_" + str(i_level) + "_atten"] = h
|
|
||||||
h = self.mid.attn_1(h)
|
|
||||||
h = self.mid.block_2(h, temb)
|
|
||||||
hs["mid_atten"] = h
|
|
||||||
|
|
||||||
# end
|
|
||||||
h = self.norm_out(h)
|
|
||||||
h = nonlinearity(h)
|
|
||||||
h = self.conv_out(h)
|
|
||||||
# hs.append(h)
|
|
||||||
hs["out"] = h
|
|
||||||
|
|
||||||
return hs
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadDecoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
ch,
|
|
||||||
out_ch,
|
|
||||||
ch_mult=(1, 2, 4, 8),
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_resolutions=(16,),
|
|
||||||
dropout=0.0,
|
|
||||||
resamp_with_conv=True,
|
|
||||||
in_channels=3,
|
|
||||||
resolution=512,
|
|
||||||
z_channels=256,
|
|
||||||
give_pre_end=False,
|
|
||||||
enable_mid=True,
|
|
||||||
head_size=1,
|
|
||||||
**ignorekwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.ch = ch
|
|
||||||
self.temb_ch = 0
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.give_pre_end = give_pre_end
|
|
||||||
self.enable_mid = enable_mid
|
|
||||||
|
|
||||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
|
||||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
|
||||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
|
||||||
print(
|
|
||||||
"Working with z of shape {} = {} dimensions.".format(
|
|
||||||
self.z_shape, np.prod(self.z_shape)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# z to block_in
|
|
||||||
self.conv_in = torch.nn.Conv2d(
|
|
||||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
if self.enable_mid:
|
|
||||||
self.mid = nn.Module()
|
|
||||||
self.mid.block_1 = ResnetBlock(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_in,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
|
||||||
self.mid.block_2 = ResnetBlock(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_in,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
|
|
||||||
# upsampling
|
|
||||||
self.up = nn.ModuleList()
|
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
|
||||||
block = nn.ModuleList()
|
|
||||||
attn = nn.ModuleList()
|
|
||||||
block_out = ch * ch_mult[i_level]
|
|
||||||
for i_block in range(self.num_res_blocks + 1):
|
|
||||||
block.append(
|
|
||||||
ResnetBlock(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_out,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
block_in = block_out
|
|
||||||
if curr_res in attn_resolutions:
|
|
||||||
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
|
||||||
up = nn.Module()
|
|
||||||
up.block = block
|
|
||||||
up.attn = attn
|
|
||||||
if i_level != 0:
|
|
||||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
self.up.insert(0, up) # prepend to get consistent order
|
|
||||||
|
|
||||||
# end
|
|
||||||
self.norm_out = Normalize(block_in)
|
|
||||||
self.conv_out = torch.nn.Conv2d(
|
|
||||||
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
# assert z.shape[1:] == self.z_shape[1:]
|
|
||||||
self.last_z_shape = z.shape
|
|
||||||
|
|
||||||
# timestep embedding
|
|
||||||
temb = None
|
|
||||||
|
|
||||||
# z to block_in
|
|
||||||
h = self.conv_in(z)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
if self.enable_mid:
|
|
||||||
h = self.mid.block_1(h, temb)
|
|
||||||
h = self.mid.attn_1(h)
|
|
||||||
h = self.mid.block_2(h, temb)
|
|
||||||
|
|
||||||
# upsampling
|
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
|
||||||
for i_block in range(self.num_res_blocks + 1):
|
|
||||||
h = self.up[i_level].block[i_block](h, temb)
|
|
||||||
if len(self.up[i_level].attn) > 0:
|
|
||||||
h = self.up[i_level].attn[i_block](h)
|
|
||||||
if i_level != 0:
|
|
||||||
h = self.up[i_level].upsample(h)
|
|
||||||
|
|
||||||
# end
|
|
||||||
if self.give_pre_end:
|
|
||||||
return h
|
|
||||||
|
|
||||||
h = self.norm_out(h)
|
|
||||||
h = nonlinearity(h)
|
|
||||||
h = self.conv_out(h)
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadDecoderTransformer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
ch,
|
|
||||||
out_ch,
|
|
||||||
ch_mult=(1, 2, 4, 8),
|
|
||||||
num_res_blocks=2,
|
|
||||||
attn_resolutions=(16,),
|
|
||||||
dropout=0.0,
|
|
||||||
resamp_with_conv=True,
|
|
||||||
in_channels=3,
|
|
||||||
resolution=512,
|
|
||||||
z_channels=256,
|
|
||||||
give_pre_end=False,
|
|
||||||
enable_mid=True,
|
|
||||||
head_size=1,
|
|
||||||
**ignorekwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.ch = ch
|
|
||||||
self.temb_ch = 0
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.give_pre_end = give_pre_end
|
|
||||||
self.enable_mid = enable_mid
|
|
||||||
|
|
||||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
|
||||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
|
||||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
|
||||||
print(
|
|
||||||
"Working with z of shape {} = {} dimensions.".format(
|
|
||||||
self.z_shape, np.prod(self.z_shape)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# z to block_in
|
|
||||||
self.conv_in = torch.nn.Conv2d(
|
|
||||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
if self.enable_mid:
|
|
||||||
self.mid = nn.Module()
|
|
||||||
self.mid.block_1 = ResnetBlock(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_in,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
|
||||||
self.mid.block_2 = ResnetBlock(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_in,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
|
|
||||||
# upsampling
|
|
||||||
self.up = nn.ModuleList()
|
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
|
||||||
block = nn.ModuleList()
|
|
||||||
attn = nn.ModuleList()
|
|
||||||
block_out = ch * ch_mult[i_level]
|
|
||||||
for i_block in range(self.num_res_blocks + 1):
|
|
||||||
block.append(
|
|
||||||
ResnetBlock(
|
|
||||||
in_channels=block_in,
|
|
||||||
out_channels=block_out,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
block_in = block_out
|
|
||||||
if curr_res in attn_resolutions:
|
|
||||||
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
|
||||||
up = nn.Module()
|
|
||||||
up.block = block
|
|
||||||
up.attn = attn
|
|
||||||
if i_level != 0:
|
|
||||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
self.up.insert(0, up) # prepend to get consistent order
|
|
||||||
|
|
||||||
# end
|
|
||||||
self.norm_out = Normalize(block_in)
|
|
||||||
self.conv_out = torch.nn.Conv2d(
|
|
||||||
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, z, hs):
|
|
||||||
# assert z.shape[1:] == self.z_shape[1:]
|
|
||||||
# self.last_z_shape = z.shape
|
|
||||||
|
|
||||||
# timestep embedding
|
|
||||||
temb = None
|
|
||||||
|
|
||||||
# z to block_in
|
|
||||||
h = self.conv_in(z)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
if self.enable_mid:
|
|
||||||
h = self.mid.block_1(h, temb)
|
|
||||||
h = self.mid.attn_1(h, hs["mid_atten"])
|
|
||||||
h = self.mid.block_2(h, temb)
|
|
||||||
|
|
||||||
# upsampling
|
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
|
||||||
for i_block in range(self.num_res_blocks + 1):
|
|
||||||
h = self.up[i_level].block[i_block](h, temb)
|
|
||||||
if len(self.up[i_level].attn) > 0:
|
|
||||||
h = self.up[i_level].attn[i_block](
|
|
||||||
h, hs["block_" + str(i_level) + "_atten"]
|
|
||||||
)
|
|
||||||
# hfeature = h.clone()
|
|
||||||
if i_level != 0:
|
|
||||||
h = self.up[i_level].upsample(h)
|
|
||||||
|
|
||||||
# end
|
|
||||||
if self.give_pre_end:
|
|
||||||
return h
|
|
||||||
|
|
||||||
h = self.norm_out(h)
|
|
||||||
h = nonlinearity(h)
|
|
||||||
h = self.conv_out(h)
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
class RestoreFormer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state_dict,
|
|
||||||
):
|
|
||||||
super(RestoreFormer, self).__init__()
|
|
||||||
|
|
||||||
n_embed = 1024
|
|
||||||
embed_dim = 256
|
|
||||||
ch = 64
|
|
||||||
out_ch = 3
|
|
||||||
ch_mult = (1, 2, 2, 4, 4, 8)
|
|
||||||
num_res_blocks = 2
|
|
||||||
attn_resolutions = (16,)
|
|
||||||
dropout = 0.0
|
|
||||||
in_channels = 3
|
|
||||||
resolution = 512
|
|
||||||
z_channels = 256
|
|
||||||
double_z = False
|
|
||||||
enable_mid = True
|
|
||||||
fix_decoder = False
|
|
||||||
fix_codebook = True
|
|
||||||
fix_encoder = False
|
|
||||||
head_size = 8
|
|
||||||
|
|
||||||
self.model_arch = "RestoreFormer"
|
|
||||||
self.sub_type = "Face SR"
|
|
||||||
self.scale = 8
|
|
||||||
self.in_nc = 3
|
|
||||||
self.out_nc = out_ch
|
|
||||||
self.state = state_dict
|
|
||||||
|
|
||||||
self.supports_fp16 = False
|
|
||||||
self.supports_bf16 = True
|
|
||||||
self.min_size_restriction = 16
|
|
||||||
|
|
||||||
self.encoder = MultiHeadEncoder(
|
|
||||||
ch=ch,
|
|
||||||
out_ch=out_ch,
|
|
||||||
ch_mult=ch_mult,
|
|
||||||
num_res_blocks=num_res_blocks,
|
|
||||||
attn_resolutions=attn_resolutions,
|
|
||||||
dropout=dropout,
|
|
||||||
in_channels=in_channels,
|
|
||||||
resolution=resolution,
|
|
||||||
z_channels=z_channels,
|
|
||||||
double_z=double_z,
|
|
||||||
enable_mid=enable_mid,
|
|
||||||
head_size=head_size,
|
|
||||||
)
|
|
||||||
self.decoder = MultiHeadDecoderTransformer(
|
|
||||||
ch=ch,
|
|
||||||
out_ch=out_ch,
|
|
||||||
ch_mult=ch_mult,
|
|
||||||
num_res_blocks=num_res_blocks,
|
|
||||||
attn_resolutions=attn_resolutions,
|
|
||||||
dropout=dropout,
|
|
||||||
in_channels=in_channels,
|
|
||||||
resolution=resolution,
|
|
||||||
z_channels=z_channels,
|
|
||||||
enable_mid=enable_mid,
|
|
||||||
head_size=head_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
|
|
||||||
|
|
||||||
self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
|
|
||||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
|
|
||||||
|
|
||||||
if fix_decoder:
|
|
||||||
for _, param in self.decoder.named_parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
for _, param in self.post_quant_conv.named_parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
for _, param in self.quantize.named_parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
elif fix_codebook:
|
|
||||||
for _, param in self.quantize.named_parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
if fix_encoder:
|
|
||||||
for _, param in self.encoder.named_parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
self.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
hs = self.encoder(x)
|
|
||||||
h = self.quant_conv(hs["out"])
|
|
||||||
quant, emb_loss, info = self.quantize(h)
|
|
||||||
return quant, emb_loss, info, hs
|
|
||||||
|
|
||||||
def decode(self, quant, hs):
|
|
||||||
quant = self.post_quant_conv(quant)
|
|
||||||
dec = self.decoder(quant, hs)
|
|
||||||
|
|
||||||
return dec
|
|
||||||
|
|
||||||
def forward(self, input, **kwargs):
|
|
||||||
quant, diff, info, hs = self.encode(input)
|
|
||||||
dec = self.decode(quant, hs)
|
|
||||||
|
|
||||||
return dec, None
|
|
||||||
@ -1,865 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# type: ignore
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
|
||||||
from .upfirdn2d import upfirdn2d
|
|
||||||
|
|
||||||
|
|
||||||
class NormStyleCode(nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
"""Normalize the style codes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Style codes with shape (b, c).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Normalized tensor.
|
|
||||||
"""
|
|
||||||
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
|
|
||||||
|
|
||||||
|
|
||||||
def make_resample_kernel(k):
|
|
||||||
"""Make resampling kernel for UpFirDn.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
k (list[int]): A list indicating the 1D resample kernel magnitude.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: 2D resampled kernel.
|
|
||||||
"""
|
|
||||||
k = torch.tensor(k, dtype=torch.float32)
|
|
||||||
if k.ndim == 1:
|
|
||||||
k = k[None, :] * k[:, None] # to 2D kernel, outer product
|
|
||||||
# normalize
|
|
||||||
k /= k.sum()
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
class UpFirDnUpsample(nn.Module):
|
|
||||||
"""Upsample, FIR filter, and downsample (upsampole version).
|
|
||||||
|
|
||||||
References:
|
|
||||||
1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501
|
|
||||||
2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501
|
|
||||||
|
|
||||||
Args:
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
|
||||||
magnitude.
|
|
||||||
factor (int): Upsampling scale factor. Default: 2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, resample_kernel, factor=2):
|
|
||||||
super(UpFirDnUpsample, self).__init__()
|
|
||||||
self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
|
|
||||||
self.factor = factor
|
|
||||||
|
|
||||||
pad = self.kernel.shape[0] - factor
|
|
||||||
self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(factor={self.factor})"
|
|
||||||
|
|
||||||
|
|
||||||
class UpFirDnDownsample(nn.Module):
|
|
||||||
"""Upsample, FIR filter, and downsample (downsampole version).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
|
||||||
magnitude.
|
|
||||||
factor (int): Downsampling scale factor. Default: 2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, resample_kernel, factor=2):
|
|
||||||
super(UpFirDnDownsample, self).__init__()
|
|
||||||
self.kernel = make_resample_kernel(resample_kernel)
|
|
||||||
self.factor = factor
|
|
||||||
|
|
||||||
pad = self.kernel.shape[0] - factor
|
|
||||||
self.pad = ((pad + 1) // 2, pad // 2)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(factor={self.factor})"
|
|
||||||
|
|
||||||
|
|
||||||
class UpFirDnSmooth(nn.Module):
|
|
||||||
"""Upsample, FIR filter, and downsample (smooth version).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
|
||||||
magnitude.
|
|
||||||
upsample_factor (int): Upsampling scale factor. Default: 1.
|
|
||||||
downsample_factor (int): Downsampling scale factor. Default: 1.
|
|
||||||
kernel_size (int): Kernel size: Default: 1.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1
|
|
||||||
):
|
|
||||||
super(UpFirDnSmooth, self).__init__()
|
|
||||||
self.upsample_factor = upsample_factor
|
|
||||||
self.downsample_factor = downsample_factor
|
|
||||||
self.kernel = make_resample_kernel(resample_kernel)
|
|
||||||
if upsample_factor > 1:
|
|
||||||
self.kernel = self.kernel * (upsample_factor**2)
|
|
||||||
|
|
||||||
if upsample_factor > 1:
|
|
||||||
pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
|
|
||||||
self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
|
|
||||||
elif downsample_factor > 1:
|
|
||||||
pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
|
|
||||||
self.pad = ((pad + 1) // 2, pad // 2)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
f"{self.__class__.__name__}(upsample_factor={self.upsample_factor}"
|
|
||||||
f", downsample_factor={self.downsample_factor})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EqualLinear(nn.Module):
|
|
||||||
"""Equalized Linear as StyleGAN2.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): Size of each sample.
|
|
||||||
out_channels (int): Size of each output sample.
|
|
||||||
bias (bool): If set to ``False``, the layer will not learn an additive
|
|
||||||
bias. Default: ``True``.
|
|
||||||
bias_init_val (float): Bias initialized value. Default: 0.
|
|
||||||
lr_mul (float): Learning rate multiplier. Default: 1.
|
|
||||||
activation (None | str): The activation after ``linear`` operation.
|
|
||||||
Supported: 'fused_lrelu', None. Default: None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
lr_mul=1,
|
|
||||||
activation=None,
|
|
||||||
):
|
|
||||||
super(EqualLinear, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.lr_mul = lr_mul
|
|
||||||
self.activation = activation
|
|
||||||
if self.activation not in ["fused_lrelu", None]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Wrong activation value in EqualLinear: {activation}"
|
|
||||||
"Supported ones are: ['fused_lrelu', None]."
|
|
||||||
)
|
|
||||||
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
|
|
||||||
if bias:
|
|
||||||
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.bias is None:
|
|
||||||
bias = None
|
|
||||||
else:
|
|
||||||
bias = self.bias * self.lr_mul
|
|
||||||
if self.activation == "fused_lrelu":
|
|
||||||
out = F.linear(x, self.weight * self.scale)
|
|
||||||
out = fused_leaky_relu(out, bias)
|
|
||||||
else:
|
|
||||||
out = F.linear(x, self.weight * self.scale, bias=bias)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
f"{self.__class__.__name__}(in_channels={self.in_channels}, "
|
|
||||||
f"out_channels={self.out_channels}, bias={self.bias is not None})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModulatedConv2d(nn.Module):
|
|
||||||
"""Modulated Conv2d used in StyleGAN2.
|
|
||||||
|
|
||||||
There is no bias in ModulatedConv2d.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Size of the convolving kernel.
|
|
||||||
num_style_feat (int): Channel number of style features.
|
|
||||||
demodulate (bool): Whether to demodulate in the conv layer.
|
|
||||||
Default: True.
|
|
||||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
|
||||||
Default: None.
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
|
||||||
magnitude. Default: (1, 3, 3, 1).
|
|
||||||
eps (float): A value added to the denominator for numerical stability.
|
|
||||||
Default: 1e-8.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
eps=1e-8,
|
|
||||||
):
|
|
||||||
super(ModulatedConv2d, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.demodulate = demodulate
|
|
||||||
self.sample_mode = sample_mode
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
if self.sample_mode == "upsample":
|
|
||||||
self.smooth = UpFirDnSmooth(
|
|
||||||
resample_kernel,
|
|
||||||
upsample_factor=2,
|
|
||||||
downsample_factor=1,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
)
|
|
||||||
elif self.sample_mode == "downsample":
|
|
||||||
self.smooth = UpFirDnSmooth(
|
|
||||||
resample_kernel,
|
|
||||||
upsample_factor=1,
|
|
||||||
downsample_factor=2,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
)
|
|
||||||
elif self.sample_mode is None:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Wrong sample mode {self.sample_mode}, "
|
|
||||||
"supported ones are ['upsample', 'downsample', None]."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
|
||||||
# modulation inside each modulated conv
|
|
||||||
self.modulation = EqualLinear(
|
|
||||||
num_style_feat,
|
|
||||||
in_channels,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=1,
|
|
||||||
lr_mul=1,
|
|
||||||
activation=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(
|
|
||||||
torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
|
|
||||||
)
|
|
||||||
self.padding = kernel_size // 2
|
|
||||||
|
|
||||||
def forward(self, x, style):
|
|
||||||
"""Forward function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Tensor with shape (b, c, h, w).
|
|
||||||
style (Tensor): Tensor with shape (b, num_style_feat).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Modulated tensor after convolution.
|
|
||||||
"""
|
|
||||||
b, c, h, w = x.shape # c = c_in
|
|
||||||
# weight modulation
|
|
||||||
style = self.modulation(style).view(b, 1, c, 1, 1)
|
|
||||||
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
|
|
||||||
weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
|
|
||||||
|
|
||||||
if self.demodulate:
|
|
||||||
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
|
|
||||||
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
|
|
||||||
|
|
||||||
weight = weight.view(
|
|
||||||
b * self.out_channels, c, self.kernel_size, self.kernel_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.sample_mode == "upsample":
|
|
||||||
x = x.view(1, b * c, h, w)
|
|
||||||
weight = weight.view(
|
|
||||||
b, self.out_channels, c, self.kernel_size, self.kernel_size
|
|
||||||
)
|
|
||||||
weight = weight.transpose(1, 2).reshape(
|
|
||||||
b * c, self.out_channels, self.kernel_size, self.kernel_size
|
|
||||||
)
|
|
||||||
out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
|
|
||||||
out = out.view(b, self.out_channels, *out.shape[2:4])
|
|
||||||
out = self.smooth(out)
|
|
||||||
elif self.sample_mode == "downsample":
|
|
||||||
x = self.smooth(x)
|
|
||||||
x = x.view(1, b * c, *x.shape[2:4])
|
|
||||||
out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
|
|
||||||
out = out.view(b, self.out_channels, *out.shape[2:4])
|
|
||||||
else:
|
|
||||||
x = x.view(1, b * c, h, w)
|
|
||||||
# weight: (b*c_out, c_in, k, k), groups=b
|
|
||||||
out = F.conv2d(x, weight, padding=self.padding, groups=b)
|
|
||||||
out = out.view(b, self.out_channels, *out.shape[2:4])
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
f"{self.__class__.__name__}(in_channels={self.in_channels}, "
|
|
||||||
f"out_channels={self.out_channels}, "
|
|
||||||
f"kernel_size={self.kernel_size}, "
|
|
||||||
f"demodulate={self.demodulate}, sample_mode={self.sample_mode})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StyleConv(nn.Module):
|
|
||||||
"""Style conv.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Size of the convolving kernel.
|
|
||||||
num_style_feat (int): Channel number of style features.
|
|
||||||
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
|
||||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
|
||||||
Default: None.
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
|
||||||
magnitude. Default: (1, 3, 3, 1).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
):
|
|
||||||
super(StyleConv, self).__init__()
|
|
||||||
self.modulated_conv = ModulatedConv2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
num_style_feat,
|
|
||||||
demodulate=demodulate,
|
|
||||||
sample_mode=sample_mode,
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
)
|
|
||||||
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
|
|
||||||
self.activate = FusedLeakyReLU(out_channels)
|
|
||||||
|
|
||||||
def forward(self, x, style, noise=None):
|
|
||||||
# modulate
|
|
||||||
out = self.modulated_conv(x, style)
|
|
||||||
# noise injection
|
|
||||||
if noise is None:
|
|
||||||
b, _, h, w = out.shape
|
|
||||||
noise = out.new_empty(b, 1, h, w).normal_()
|
|
||||||
out = out + self.weight * noise
|
|
||||||
# activation (with bias)
|
|
||||||
out = self.activate(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ToRGB(nn.Module):
|
|
||||||
"""To RGB from features.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of input.
|
|
||||||
num_style_feat (int): Channel number of style features.
|
|
||||||
upsample (bool): Whether to upsample. Default: True.
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
|
||||||
magnitude. Default: (1, 3, 3, 1).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)
|
|
||||||
):
|
|
||||||
super(ToRGB, self).__init__()
|
|
||||||
if upsample:
|
|
||||||
self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
|
|
||||||
else:
|
|
||||||
self.upsample = None
|
|
||||||
self.modulated_conv = ModulatedConv2d(
|
|
||||||
in_channels,
|
|
||||||
3,
|
|
||||||
kernel_size=1,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=False,
|
|
||||||
sample_mode=None,
|
|
||||||
)
|
|
||||||
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
|
||||||
|
|
||||||
def forward(self, x, style, skip=None):
|
|
||||||
"""Forward function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): Feature tensor with shape (b, c, h, w).
|
|
||||||
style (Tensor): Tensor with shape (b, num_style_feat).
|
|
||||||
skip (Tensor): Base/skip tensor. Default: None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: RGB images.
|
|
||||||
"""
|
|
||||||
out = self.modulated_conv(x, style)
|
|
||||||
out = out + self.bias
|
|
||||||
if skip is not None:
|
|
||||||
if self.upsample:
|
|
||||||
skip = self.upsample(skip)
|
|
||||||
out = out + skip
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ConstantInput(nn.Module):
|
|
||||||
"""Constant input.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_channel (int): Channel number of constant input.
|
|
||||||
size (int): Spatial size of constant input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_channel, size):
|
|
||||||
super(ConstantInput, self).__init__()
|
|
||||||
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
|
|
||||||
|
|
||||||
def forward(self, batch):
|
|
||||||
out = self.weight.repeat(batch, 1, 1, 1)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class StyleGAN2Generator(nn.Module):
|
|
||||||
"""StyleGAN2 Generator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
out_size (int): The spatial size of outputs.
|
|
||||||
num_style_feat (int): Channel number of style features. Default: 512.
|
|
||||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
|
||||||
channel_multiplier (int): Channel multiplier for large networks of
|
|
||||||
StyleGAN2. Default: 2.
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
|
||||||
magnitude. A cross production will be applied to extent 1D resample
|
|
||||||
kernel to 2D resample kernel. Default: (1, 3, 3, 1).
|
|
||||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
|
||||||
narrow (float): Narrow ratio for channels. Default: 1.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
out_size,
|
|
||||||
num_style_feat=512,
|
|
||||||
num_mlp=8,
|
|
||||||
channel_multiplier=2,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
lr_mlp=0.01,
|
|
||||||
narrow=1,
|
|
||||||
):
|
|
||||||
super(StyleGAN2Generator, self).__init__()
|
|
||||||
# Style MLP layers
|
|
||||||
self.num_style_feat = num_style_feat
|
|
||||||
style_mlp_layers = [NormStyleCode()]
|
|
||||||
for i in range(num_mlp):
|
|
||||||
style_mlp_layers.append(
|
|
||||||
EqualLinear(
|
|
||||||
num_style_feat,
|
|
||||||
num_style_feat,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
lr_mul=lr_mlp,
|
|
||||||
activation="fused_lrelu",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.style_mlp = nn.Sequential(*style_mlp_layers)
|
|
||||||
|
|
||||||
channels = {
|
|
||||||
"4": int(512 * narrow),
|
|
||||||
"8": int(512 * narrow),
|
|
||||||
"16": int(512 * narrow),
|
|
||||||
"32": int(512 * narrow),
|
|
||||||
"64": int(256 * channel_multiplier * narrow),
|
|
||||||
"128": int(128 * channel_multiplier * narrow),
|
|
||||||
"256": int(64 * channel_multiplier * narrow),
|
|
||||||
"512": int(32 * channel_multiplier * narrow),
|
|
||||||
"1024": int(16 * channel_multiplier * narrow),
|
|
||||||
}
|
|
||||||
self.channels = channels
|
|
||||||
|
|
||||||
self.constant_input = ConstantInput(channels["4"], size=4)
|
|
||||||
self.style_conv1 = StyleConv(
|
|
||||||
channels["4"],
|
|
||||||
channels["4"],
|
|
||||||
kernel_size=3,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
)
|
|
||||||
self.to_rgb1 = ToRGB(
|
|
||||||
channels["4"],
|
|
||||||
num_style_feat,
|
|
||||||
upsample=False,
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log_size = int(math.log(out_size, 2))
|
|
||||||
self.num_layers = (self.log_size - 2) * 2 + 1
|
|
||||||
self.num_latent = self.log_size * 2 - 2
|
|
||||||
|
|
||||||
self.style_convs = nn.ModuleList()
|
|
||||||
self.to_rgbs = nn.ModuleList()
|
|
||||||
self.noises = nn.Module()
|
|
||||||
|
|
||||||
in_channels = channels["4"]
|
|
||||||
# noise
|
|
||||||
for layer_idx in range(self.num_layers):
|
|
||||||
resolution = 2 ** ((layer_idx + 5) // 2)
|
|
||||||
shape = [1, 1, resolution, resolution]
|
|
||||||
self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
|
|
||||||
# style convs and to_rgbs
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
out_channels = channels[f"{2**i}"]
|
|
||||||
self.style_convs.append(
|
|
||||||
StyleConv(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode="upsample",
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.style_convs.append(
|
|
||||||
StyleConv(
|
|
||||||
out_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.to_rgbs.append(
|
|
||||||
ToRGB(
|
|
||||||
out_channels,
|
|
||||||
num_style_feat,
|
|
||||||
upsample=True,
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
in_channels = out_channels
|
|
||||||
|
|
||||||
def make_noise(self):
|
|
||||||
"""Make noise for noise injection."""
|
|
||||||
device = self.constant_input.weight.device
|
|
||||||
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
|
||||||
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
for _ in range(2):
|
|
||||||
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
|
||||||
|
|
||||||
return noises
|
|
||||||
|
|
||||||
def get_latent(self, x):
|
|
||||||
return self.style_mlp(x)
|
|
||||||
|
|
||||||
def mean_latent(self, num_latent):
|
|
||||||
latent_in = torch.randn(
|
|
||||||
num_latent, self.num_style_feat, device=self.constant_input.weight.device
|
|
||||||
)
|
|
||||||
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
|
||||||
return latent
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
styles,
|
|
||||||
input_is_latent=False,
|
|
||||||
noise=None,
|
|
||||||
randomize_noise=True,
|
|
||||||
truncation=1,
|
|
||||||
truncation_latent=None,
|
|
||||||
inject_index=None,
|
|
||||||
return_latents=False,
|
|
||||||
):
|
|
||||||
"""Forward function for StyleGAN2Generator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
styles (list[Tensor]): Sample codes of styles.
|
|
||||||
input_is_latent (bool): Whether input is latent style.
|
|
||||||
Default: False.
|
|
||||||
noise (Tensor | None): Input noise or None. Default: None.
|
|
||||||
randomize_noise (bool): Randomize noise, used when 'noise' is
|
|
||||||
False. Default: True.
|
|
||||||
truncation (float): TODO. Default: 1.
|
|
||||||
truncation_latent (Tensor | None): TODO. Default: None.
|
|
||||||
inject_index (int | None): The injection index for mixing noise.
|
|
||||||
Default: None.
|
|
||||||
return_latents (bool): Whether to return style latents.
|
|
||||||
Default: False.
|
|
||||||
"""
|
|
||||||
# style codes -> latents with Style MLP layer
|
|
||||||
if not input_is_latent:
|
|
||||||
styles = [self.style_mlp(s) for s in styles]
|
|
||||||
# noises
|
|
||||||
if noise is None:
|
|
||||||
if randomize_noise:
|
|
||||||
noise = [None] * self.num_layers # for each style conv layer
|
|
||||||
else: # use the stored noise
|
|
||||||
noise = [
|
|
||||||
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
|
|
||||||
]
|
|
||||||
# style truncation
|
|
||||||
if truncation < 1:
|
|
||||||
style_truncation = []
|
|
||||||
for style in styles:
|
|
||||||
style_truncation.append(
|
|
||||||
truncation_latent + truncation * (style - truncation_latent)
|
|
||||||
)
|
|
||||||
styles = style_truncation
|
|
||||||
# get style latent with injection
|
|
||||||
if len(styles) == 1:
|
|
||||||
inject_index = self.num_latent
|
|
||||||
|
|
||||||
if styles[0].ndim < 3:
|
|
||||||
# repeat latent code for all the layers
|
|
||||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
else: # used for encoder with different latent code for each layer
|
|
||||||
latent = styles[0]
|
|
||||||
elif len(styles) == 2: # mixing noises
|
|
||||||
if inject_index is None:
|
|
||||||
inject_index = random.randint(1, self.num_latent - 1)
|
|
||||||
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
latent2 = (
|
|
||||||
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
|
||||||
)
|
|
||||||
latent = torch.cat([latent1, latent2], 1)
|
|
||||||
|
|
||||||
# main generation
|
|
||||||
out = self.constant_input(latent.shape[0])
|
|
||||||
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
|
||||||
skip = self.to_rgb1(out, latent[:, 1])
|
|
||||||
|
|
||||||
i = 1
|
|
||||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
|
||||||
self.style_convs[::2],
|
|
||||||
self.style_convs[1::2],
|
|
||||||
noise[1::2],
|
|
||||||
noise[2::2],
|
|
||||||
self.to_rgbs,
|
|
||||||
):
|
|
||||||
out = conv1(out, latent[:, i], noise=noise1)
|
|
||||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
|
||||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
|
||||||
i += 2
|
|
||||||
|
|
||||||
image = skip
|
|
||||||
|
|
||||||
if return_latents:
|
|
||||||
return image, latent
|
|
||||||
else:
|
|
||||||
return image, None
|
|
||||||
|
|
||||||
|
|
||||||
class ScaledLeakyReLU(nn.Module):
|
|
||||||
"""Scaled LeakyReLU.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
negative_slope (float): Negative slope. Default: 0.2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, negative_slope=0.2):
|
|
||||||
super(ScaledLeakyReLU, self).__init__()
|
|
||||||
self.negative_slope = negative_slope
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = F.leaky_relu(x, negative_slope=self.negative_slope)
|
|
||||||
return out * math.sqrt(2)
|
|
||||||
|
|
||||||
|
|
||||||
class EqualConv2d(nn.Module):
|
|
||||||
"""Equalized Linear as StyleGAN2.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Size of the convolving kernel.
|
|
||||||
stride (int): Stride of the convolution. Default: 1
|
|
||||||
padding (int): Zero-padding added to both sides of the input.
|
|
||||||
Default: 0.
|
|
||||||
bias (bool): If ``True``, adds a learnable bias to the output.
|
|
||||||
Default: ``True``.
|
|
||||||
bias_init_val (float): Bias initialized value. Default: 0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
):
|
|
||||||
super(EqualConv2d, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.stride = stride
|
|
||||||
self.padding = padding
|
|
||||||
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(
|
|
||||||
torch.randn(out_channels, in_channels, kernel_size, kernel_size)
|
|
||||||
)
|
|
||||||
if bias:
|
|
||||||
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = F.conv2d(
|
|
||||||
x,
|
|
||||||
self.weight * self.scale,
|
|
||||||
bias=self.bias,
|
|
||||||
stride=self.stride,
|
|
||||||
padding=self.padding,
|
|
||||||
)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
f"{self.__class__.__name__}(in_channels={self.in_channels}, "
|
|
||||||
f"out_channels={self.out_channels}, "
|
|
||||||
f"kernel_size={self.kernel_size},"
|
|
||||||
f" stride={self.stride}, padding={self.padding}, "
|
|
||||||
f"bias={self.bias is not None})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ConvLayer(nn.Sequential):
|
|
||||||
"""Conv Layer used in StyleGAN2 Discriminator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Kernel size.
|
|
||||||
downsample (bool): Whether downsample by a factor of 2.
|
|
||||||
Default: False.
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample
|
|
||||||
kernel magnitude. A cross production will be applied to
|
|
||||||
extent 1D resample kernel to 2D resample kernel.
|
|
||||||
Default: (1, 3, 3, 1).
|
|
||||||
bias (bool): Whether with bias. Default: True.
|
|
||||||
activate (bool): Whether use activateion. Default: True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
downsample=False,
|
|
||||||
resample_kernel=(1, 3, 3, 1),
|
|
||||||
bias=True,
|
|
||||||
activate=True,
|
|
||||||
):
|
|
||||||
layers = []
|
|
||||||
# downsample
|
|
||||||
if downsample:
|
|
||||||
layers.append(
|
|
||||||
UpFirDnSmooth(
|
|
||||||
resample_kernel,
|
|
||||||
upsample_factor=1,
|
|
||||||
downsample_factor=2,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
stride = 2
|
|
||||||
self.padding = 0
|
|
||||||
else:
|
|
||||||
stride = 1
|
|
||||||
self.padding = kernel_size // 2
|
|
||||||
# conv
|
|
||||||
layers.append(
|
|
||||||
EqualConv2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
padding=self.padding,
|
|
||||||
bias=bias and not activate,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# activation
|
|
||||||
if activate:
|
|
||||||
if bias:
|
|
||||||
layers.append(FusedLeakyReLU(out_channels))
|
|
||||||
else:
|
|
||||||
layers.append(ScaledLeakyReLU(0.2))
|
|
||||||
|
|
||||||
super(ConvLayer, self).__init__(*layers)
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
|
||||||
"""Residual block used in StyleGAN2 Discriminator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
resample_kernel (list[int]): A list indicating the 1D resample
|
|
||||||
kernel magnitude. A cross production will be applied to
|
|
||||||
extent 1D resample kernel to 2D resample kernel.
|
|
||||||
Default: (1, 3, 3, 1).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
|
|
||||||
super(ResBlock, self).__init__()
|
|
||||||
|
|
||||||
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
|
|
||||||
self.conv2 = ConvLayer(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
3,
|
|
||||||
downsample=True,
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
bias=True,
|
|
||||||
activate=True,
|
|
||||||
)
|
|
||||||
self.skip = ConvLayer(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
1,
|
|
||||||
downsample=True,
|
|
||||||
resample_kernel=resample_kernel,
|
|
||||||
bias=False,
|
|
||||||
activate=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.conv2(out)
|
|
||||||
skip = self.skip(x)
|
|
||||||
out = (out + skip) / math.sqrt(2)
|
|
||||||
return out
|
|
||||||
@ -1,709 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# type: ignore
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
|
||||||
|
|
||||||
|
|
||||||
class NormStyleCode(nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
"""Normalize the style codes.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Style codes with shape (b, c).
|
|
||||||
Returns:
|
|
||||||
Tensor: Normalized tensor.
|
|
||||||
"""
|
|
||||||
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
|
|
||||||
|
|
||||||
|
|
||||||
class EqualLinear(nn.Module):
|
|
||||||
"""Equalized Linear as StyleGAN2.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Size of each sample.
|
|
||||||
out_channels (int): Size of each output sample.
|
|
||||||
bias (bool): If set to ``False``, the layer will not learn an additive
|
|
||||||
bias. Default: ``True``.
|
|
||||||
bias_init_val (float): Bias initialized value. Default: 0.
|
|
||||||
lr_mul (float): Learning rate multiplier. Default: 1.
|
|
||||||
activation (None | str): The activation after ``linear`` operation.
|
|
||||||
Supported: 'fused_lrelu', None. Default: None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
lr_mul=1,
|
|
||||||
activation=None,
|
|
||||||
):
|
|
||||||
super(EqualLinear, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.lr_mul = lr_mul
|
|
||||||
self.activation = activation
|
|
||||||
if self.activation not in ["fused_lrelu", None]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Wrong activation value in EqualLinear: {activation}"
|
|
||||||
"Supported ones are: ['fused_lrelu', None]."
|
|
||||||
)
|
|
||||||
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
|
|
||||||
if bias:
|
|
||||||
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.bias is None:
|
|
||||||
bias = None
|
|
||||||
else:
|
|
||||||
bias = self.bias * self.lr_mul
|
|
||||||
if self.activation == "fused_lrelu":
|
|
||||||
out = F.linear(x, self.weight * self.scale)
|
|
||||||
out = fused_leaky_relu(out, bias)
|
|
||||||
else:
|
|
||||||
out = F.linear(x, self.weight * self.scale, bias=bias)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
f"{self.__class__.__name__}(in_channels={self.in_channels}, "
|
|
||||||
f"out_channels={self.out_channels}, bias={self.bias is not None})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModulatedConv2d(nn.Module):
|
|
||||||
"""Modulated Conv2d used in StyleGAN2.
|
|
||||||
There is no bias in ModulatedConv2d.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Size of the convolving kernel.
|
|
||||||
num_style_feat (int): Channel number of style features.
|
|
||||||
demodulate (bool): Whether to demodulate in the conv layer.
|
|
||||||
Default: True.
|
|
||||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
|
||||||
Default: None.
|
|
||||||
eps (float): A value added to the denominator for numerical stability.
|
|
||||||
Default: 1e-8.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
eps=1e-8,
|
|
||||||
interpolation_mode="bilinear",
|
|
||||||
):
|
|
||||||
super(ModulatedConv2d, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.demodulate = demodulate
|
|
||||||
self.sample_mode = sample_mode
|
|
||||||
self.eps = eps
|
|
||||||
self.interpolation_mode = interpolation_mode
|
|
||||||
if self.interpolation_mode == "nearest":
|
|
||||||
self.align_corners = None
|
|
||||||
else:
|
|
||||||
self.align_corners = False
|
|
||||||
|
|
||||||
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
|
||||||
# modulation inside each modulated conv
|
|
||||||
self.modulation = EqualLinear(
|
|
||||||
num_style_feat,
|
|
||||||
in_channels,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=1,
|
|
||||||
lr_mul=1,
|
|
||||||
activation=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(
|
|
||||||
torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
|
|
||||||
)
|
|
||||||
self.padding = kernel_size // 2
|
|
||||||
|
|
||||||
def forward(self, x, style):
|
|
||||||
"""Forward function.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Tensor with shape (b, c, h, w).
|
|
||||||
style (Tensor): Tensor with shape (b, num_style_feat).
|
|
||||||
Returns:
|
|
||||||
Tensor: Modulated tensor after convolution.
|
|
||||||
"""
|
|
||||||
b, c, h, w = x.shape # c = c_in
|
|
||||||
# weight modulation
|
|
||||||
style = self.modulation(style).view(b, 1, c, 1, 1)
|
|
||||||
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
|
|
||||||
weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
|
|
||||||
|
|
||||||
if self.demodulate:
|
|
||||||
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
|
|
||||||
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
|
|
||||||
|
|
||||||
weight = weight.view(
|
|
||||||
b * self.out_channels, c, self.kernel_size, self.kernel_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.sample_mode == "upsample":
|
|
||||||
x = F.interpolate(
|
|
||||||
x,
|
|
||||||
scale_factor=2,
|
|
||||||
mode=self.interpolation_mode,
|
|
||||||
align_corners=self.align_corners,
|
|
||||||
)
|
|
||||||
elif self.sample_mode == "downsample":
|
|
||||||
x = F.interpolate(
|
|
||||||
x,
|
|
||||||
scale_factor=0.5,
|
|
||||||
mode=self.interpolation_mode,
|
|
||||||
align_corners=self.align_corners,
|
|
||||||
)
|
|
||||||
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
x = x.view(1, b * c, h, w)
|
|
||||||
# weight: (b*c_out, c_in, k, k), groups=b
|
|
||||||
out = F.conv2d(x, weight, padding=self.padding, groups=b)
|
|
||||||
out = out.view(b, self.out_channels, *out.shape[2:4])
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
f"{self.__class__.__name__}(in_channels={self.in_channels}, "
|
|
||||||
f"out_channels={self.out_channels}, "
|
|
||||||
f"kernel_size={self.kernel_size}, "
|
|
||||||
f"demodulate={self.demodulate}, sample_mode={self.sample_mode})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StyleConv(nn.Module):
|
|
||||||
"""Style conv.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Size of the convolving kernel.
|
|
||||||
num_style_feat (int): Channel number of style features.
|
|
||||||
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
|
||||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
|
||||||
Default: None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
interpolation_mode="bilinear",
|
|
||||||
):
|
|
||||||
super(StyleConv, self).__init__()
|
|
||||||
self.modulated_conv = ModulatedConv2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
num_style_feat,
|
|
||||||
demodulate=demodulate,
|
|
||||||
sample_mode=sample_mode,
|
|
||||||
interpolation_mode=interpolation_mode,
|
|
||||||
)
|
|
||||||
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
|
|
||||||
self.activate = FusedLeakyReLU(out_channels)
|
|
||||||
|
|
||||||
def forward(self, x, style, noise=None):
|
|
||||||
# modulate
|
|
||||||
out = self.modulated_conv(x, style)
|
|
||||||
# noise injection
|
|
||||||
if noise is None:
|
|
||||||
b, _, h, w = out.shape
|
|
||||||
noise = out.new_empty(b, 1, h, w).normal_()
|
|
||||||
out = out + self.weight * noise
|
|
||||||
# activation (with bias)
|
|
||||||
out = self.activate(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ToRGB(nn.Module):
|
|
||||||
"""To RGB from features.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of input.
|
|
||||||
num_style_feat (int): Channel number of style features.
|
|
||||||
upsample (bool): Whether to upsample. Default: True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, in_channels, num_style_feat, upsample=True, interpolation_mode="bilinear"
|
|
||||||
):
|
|
||||||
super(ToRGB, self).__init__()
|
|
||||||
self.upsample = upsample
|
|
||||||
self.interpolation_mode = interpolation_mode
|
|
||||||
if self.interpolation_mode == "nearest":
|
|
||||||
self.align_corners = None
|
|
||||||
else:
|
|
||||||
self.align_corners = False
|
|
||||||
self.modulated_conv = ModulatedConv2d(
|
|
||||||
in_channels,
|
|
||||||
3,
|
|
||||||
kernel_size=1,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=False,
|
|
||||||
sample_mode=None,
|
|
||||||
interpolation_mode=interpolation_mode,
|
|
||||||
)
|
|
||||||
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
|
||||||
|
|
||||||
def forward(self, x, style, skip=None):
|
|
||||||
"""Forward function.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Feature tensor with shape (b, c, h, w).
|
|
||||||
style (Tensor): Tensor with shape (b, num_style_feat).
|
|
||||||
skip (Tensor): Base/skip tensor. Default: None.
|
|
||||||
Returns:
|
|
||||||
Tensor: RGB images.
|
|
||||||
"""
|
|
||||||
out = self.modulated_conv(x, style)
|
|
||||||
out = out + self.bias
|
|
||||||
if skip is not None:
|
|
||||||
if self.upsample:
|
|
||||||
skip = F.interpolate(
|
|
||||||
skip,
|
|
||||||
scale_factor=2,
|
|
||||||
mode=self.interpolation_mode,
|
|
||||||
align_corners=self.align_corners,
|
|
||||||
)
|
|
||||||
out = out + skip
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ConstantInput(nn.Module):
|
|
||||||
"""Constant input.
|
|
||||||
Args:
|
|
||||||
num_channel (int): Channel number of constant input.
|
|
||||||
size (int): Spatial size of constant input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_channel, size):
|
|
||||||
super(ConstantInput, self).__init__()
|
|
||||||
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
|
|
||||||
|
|
||||||
def forward(self, batch):
|
|
||||||
out = self.weight.repeat(batch, 1, 1, 1)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class StyleGAN2GeneratorBilinear(nn.Module):
|
|
||||||
"""StyleGAN2 Generator.
|
|
||||||
Args:
|
|
||||||
out_size (int): The spatial size of outputs.
|
|
||||||
num_style_feat (int): Channel number of style features. Default: 512.
|
|
||||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
|
||||||
channel_multiplier (int): Channel multiplier for large networks of
|
|
||||||
StyleGAN2. Default: 2.
|
|
||||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
|
||||||
narrow (float): Narrow ratio for channels. Default: 1.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
out_size,
|
|
||||||
num_style_feat=512,
|
|
||||||
num_mlp=8,
|
|
||||||
channel_multiplier=2,
|
|
||||||
lr_mlp=0.01,
|
|
||||||
narrow=1,
|
|
||||||
interpolation_mode="bilinear",
|
|
||||||
):
|
|
||||||
super(StyleGAN2GeneratorBilinear, self).__init__()
|
|
||||||
# Style MLP layers
|
|
||||||
self.num_style_feat = num_style_feat
|
|
||||||
style_mlp_layers = [NormStyleCode()]
|
|
||||||
for i in range(num_mlp):
|
|
||||||
style_mlp_layers.append(
|
|
||||||
EqualLinear(
|
|
||||||
num_style_feat,
|
|
||||||
num_style_feat,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
lr_mul=lr_mlp,
|
|
||||||
activation="fused_lrelu",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.style_mlp = nn.Sequential(*style_mlp_layers)
|
|
||||||
|
|
||||||
channels = {
|
|
||||||
"4": int(512 * narrow),
|
|
||||||
"8": int(512 * narrow),
|
|
||||||
"16": int(512 * narrow),
|
|
||||||
"32": int(512 * narrow),
|
|
||||||
"64": int(256 * channel_multiplier * narrow),
|
|
||||||
"128": int(128 * channel_multiplier * narrow),
|
|
||||||
"256": int(64 * channel_multiplier * narrow),
|
|
||||||
"512": int(32 * channel_multiplier * narrow),
|
|
||||||
"1024": int(16 * channel_multiplier * narrow),
|
|
||||||
}
|
|
||||||
self.channels = channels
|
|
||||||
|
|
||||||
self.constant_input = ConstantInput(channels["4"], size=4)
|
|
||||||
self.style_conv1 = StyleConv(
|
|
||||||
channels["4"],
|
|
||||||
channels["4"],
|
|
||||||
kernel_size=3,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
interpolation_mode=interpolation_mode,
|
|
||||||
)
|
|
||||||
self.to_rgb1 = ToRGB(
|
|
||||||
channels["4"],
|
|
||||||
num_style_feat,
|
|
||||||
upsample=False,
|
|
||||||
interpolation_mode=interpolation_mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log_size = int(math.log(out_size, 2))
|
|
||||||
self.num_layers = (self.log_size - 2) * 2 + 1
|
|
||||||
self.num_latent = self.log_size * 2 - 2
|
|
||||||
|
|
||||||
self.style_convs = nn.ModuleList()
|
|
||||||
self.to_rgbs = nn.ModuleList()
|
|
||||||
self.noises = nn.Module()
|
|
||||||
|
|
||||||
in_channels = channels["4"]
|
|
||||||
# noise
|
|
||||||
for layer_idx in range(self.num_layers):
|
|
||||||
resolution = 2 ** ((layer_idx + 5) // 2)
|
|
||||||
shape = [1, 1, resolution, resolution]
|
|
||||||
self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
|
|
||||||
# style convs and to_rgbs
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
out_channels = channels[f"{2**i}"]
|
|
||||||
self.style_convs.append(
|
|
||||||
StyleConv(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode="upsample",
|
|
||||||
interpolation_mode=interpolation_mode,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.style_convs.append(
|
|
||||||
StyleConv(
|
|
||||||
out_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
interpolation_mode=interpolation_mode,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.to_rgbs.append(
|
|
||||||
ToRGB(
|
|
||||||
out_channels,
|
|
||||||
num_style_feat,
|
|
||||||
upsample=True,
|
|
||||||
interpolation_mode=interpolation_mode,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
in_channels = out_channels
|
|
||||||
|
|
||||||
def make_noise(self):
|
|
||||||
"""Make noise for noise injection."""
|
|
||||||
device = self.constant_input.weight.device
|
|
||||||
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
|
||||||
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
for _ in range(2):
|
|
||||||
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
|
||||||
|
|
||||||
return noises
|
|
||||||
|
|
||||||
def get_latent(self, x):
|
|
||||||
return self.style_mlp(x)
|
|
||||||
|
|
||||||
def mean_latent(self, num_latent):
|
|
||||||
latent_in = torch.randn(
|
|
||||||
num_latent, self.num_style_feat, device=self.constant_input.weight.device
|
|
||||||
)
|
|
||||||
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
|
||||||
return latent
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
styles,
|
|
||||||
input_is_latent=False,
|
|
||||||
noise=None,
|
|
||||||
randomize_noise=True,
|
|
||||||
truncation=1,
|
|
||||||
truncation_latent=None,
|
|
||||||
inject_index=None,
|
|
||||||
return_latents=False,
|
|
||||||
):
|
|
||||||
"""Forward function for StyleGAN2Generator.
|
|
||||||
Args:
|
|
||||||
styles (list[Tensor]): Sample codes of styles.
|
|
||||||
input_is_latent (bool): Whether input is latent style.
|
|
||||||
Default: False.
|
|
||||||
noise (Tensor | None): Input noise or None. Default: None.
|
|
||||||
randomize_noise (bool): Randomize noise, used when 'noise' is
|
|
||||||
False. Default: True.
|
|
||||||
truncation (float): TODO. Default: 1.
|
|
||||||
truncation_latent (Tensor | None): TODO. Default: None.
|
|
||||||
inject_index (int | None): The injection index for mixing noise.
|
|
||||||
Default: None.
|
|
||||||
return_latents (bool): Whether to return style latents.
|
|
||||||
Default: False.
|
|
||||||
"""
|
|
||||||
# style codes -> latents with Style MLP layer
|
|
||||||
if not input_is_latent:
|
|
||||||
styles = [self.style_mlp(s) for s in styles]
|
|
||||||
# noises
|
|
||||||
if noise is None:
|
|
||||||
if randomize_noise:
|
|
||||||
noise = [None] * self.num_layers # for each style conv layer
|
|
||||||
else: # use the stored noise
|
|
||||||
noise = [
|
|
||||||
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
|
|
||||||
]
|
|
||||||
# style truncation
|
|
||||||
if truncation < 1:
|
|
||||||
style_truncation = []
|
|
||||||
for style in styles:
|
|
||||||
style_truncation.append(
|
|
||||||
truncation_latent + truncation * (style - truncation_latent)
|
|
||||||
)
|
|
||||||
styles = style_truncation
|
|
||||||
# get style latent with injection
|
|
||||||
if len(styles) == 1:
|
|
||||||
inject_index = self.num_latent
|
|
||||||
|
|
||||||
if styles[0].ndim < 3:
|
|
||||||
# repeat latent code for all the layers
|
|
||||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
else: # used for encoder with different latent code for each layer
|
|
||||||
latent = styles[0]
|
|
||||||
elif len(styles) == 2: # mixing noises
|
|
||||||
if inject_index is None:
|
|
||||||
inject_index = random.randint(1, self.num_latent - 1)
|
|
||||||
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
latent2 = (
|
|
||||||
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
|
||||||
)
|
|
||||||
latent = torch.cat([latent1, latent2], 1)
|
|
||||||
|
|
||||||
# main generation
|
|
||||||
out = self.constant_input(latent.shape[0])
|
|
||||||
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
|
||||||
skip = self.to_rgb1(out, latent[:, 1])
|
|
||||||
|
|
||||||
i = 1
|
|
||||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
|
||||||
self.style_convs[::2],
|
|
||||||
self.style_convs[1::2],
|
|
||||||
noise[1::2],
|
|
||||||
noise[2::2],
|
|
||||||
self.to_rgbs,
|
|
||||||
):
|
|
||||||
out = conv1(out, latent[:, i], noise=noise1)
|
|
||||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
|
||||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
|
||||||
i += 2
|
|
||||||
|
|
||||||
image = skip
|
|
||||||
|
|
||||||
if return_latents:
|
|
||||||
return image, latent
|
|
||||||
else:
|
|
||||||
return image, None
|
|
||||||
|
|
||||||
|
|
||||||
class ScaledLeakyReLU(nn.Module):
|
|
||||||
"""Scaled LeakyReLU.
|
|
||||||
Args:
|
|
||||||
negative_slope (float): Negative slope. Default: 0.2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, negative_slope=0.2):
|
|
||||||
super(ScaledLeakyReLU, self).__init__()
|
|
||||||
self.negative_slope = negative_slope
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = F.leaky_relu(x, negative_slope=self.negative_slope)
|
|
||||||
return out * math.sqrt(2)
|
|
||||||
|
|
||||||
|
|
||||||
class EqualConv2d(nn.Module):
|
|
||||||
"""Equalized Linear as StyleGAN2.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Size of the convolving kernel.
|
|
||||||
stride (int): Stride of the convolution. Default: 1
|
|
||||||
padding (int): Zero-padding added to both sides of the input.
|
|
||||||
Default: 0.
|
|
||||||
bias (bool): If ``True``, adds a learnable bias to the output.
|
|
||||||
Default: ``True``.
|
|
||||||
bias_init_val (float): Bias initialized value. Default: 0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=1,
|
|
||||||
padding=0,
|
|
||||||
bias=True,
|
|
||||||
bias_init_val=0,
|
|
||||||
):
|
|
||||||
super(EqualConv2d, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.stride = stride
|
|
||||||
self.padding = padding
|
|
||||||
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(
|
|
||||||
torch.randn(out_channels, in_channels, kernel_size, kernel_size)
|
|
||||||
)
|
|
||||||
if bias:
|
|
||||||
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = F.conv2d(
|
|
||||||
x,
|
|
||||||
self.weight * self.scale,
|
|
||||||
bias=self.bias,
|
|
||||||
stride=self.stride,
|
|
||||||
padding=self.padding,
|
|
||||||
)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
f"{self.__class__.__name__}(in_channels={self.in_channels}, "
|
|
||||||
f"out_channels={self.out_channels}, "
|
|
||||||
f"kernel_size={self.kernel_size},"
|
|
||||||
f" stride={self.stride}, padding={self.padding}, "
|
|
||||||
f"bias={self.bias is not None})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ConvLayer(nn.Sequential):
|
|
||||||
"""Conv Layer used in StyleGAN2 Discriminator.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Kernel size.
|
|
||||||
downsample (bool): Whether downsample by a factor of 2.
|
|
||||||
Default: False.
|
|
||||||
bias (bool): Whether with bias. Default: True.
|
|
||||||
activate (bool): Whether use activateion. Default: True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
downsample=False,
|
|
||||||
bias=True,
|
|
||||||
activate=True,
|
|
||||||
interpolation_mode="bilinear",
|
|
||||||
):
|
|
||||||
layers = []
|
|
||||||
self.interpolation_mode = interpolation_mode
|
|
||||||
# downsample
|
|
||||||
if downsample:
|
|
||||||
if self.interpolation_mode == "nearest":
|
|
||||||
self.align_corners = None
|
|
||||||
else:
|
|
||||||
self.align_corners = False
|
|
||||||
|
|
||||||
layers.append(
|
|
||||||
torch.nn.Upsample(
|
|
||||||
scale_factor=0.5,
|
|
||||||
mode=interpolation_mode,
|
|
||||||
align_corners=self.align_corners,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
stride = 1
|
|
||||||
self.padding = kernel_size // 2
|
|
||||||
# conv
|
|
||||||
layers.append(
|
|
||||||
EqualConv2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
padding=self.padding,
|
|
||||||
bias=bias and not activate,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# activation
|
|
||||||
if activate:
|
|
||||||
if bias:
|
|
||||||
layers.append(FusedLeakyReLU(out_channels))
|
|
||||||
else:
|
|
||||||
layers.append(ScaledLeakyReLU(0.2))
|
|
||||||
|
|
||||||
super(ConvLayer, self).__init__(*layers)
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
|
||||||
"""Residual block used in StyleGAN2 Discriminator.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, interpolation_mode="bilinear"):
|
|
||||||
super(ResBlock, self).__init__()
|
|
||||||
|
|
||||||
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
|
|
||||||
self.conv2 = ConvLayer(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
3,
|
|
||||||
downsample=True,
|
|
||||||
interpolation_mode=interpolation_mode,
|
|
||||||
bias=True,
|
|
||||||
activate=True,
|
|
||||||
)
|
|
||||||
self.skip = ConvLayer(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
1,
|
|
||||||
downsample=True,
|
|
||||||
interpolation_mode=interpolation_mode,
|
|
||||||
bias=False,
|
|
||||||
activate=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.conv2(out)
|
|
||||||
skip = self.skip(x)
|
|
||||||
out = (out + skip) / math.sqrt(2)
|
|
||||||
return out
|
|
||||||
@ -1,453 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# type: ignore
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from torch.nn import init
|
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
|
|
||||||
"""Initialize network weights.
|
|
||||||
Args:
|
|
||||||
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
|
|
||||||
scale (float): Scale initialized weights, especially for residual
|
|
||||||
blocks. Default: 1.
|
|
||||||
bias_fill (float): The value to fill bias. Default: 0
|
|
||||||
kwargs (dict): Other arguments for initialization function.
|
|
||||||
"""
|
|
||||||
if not isinstance(module_list, list):
|
|
||||||
module_list = [module_list]
|
|
||||||
for module in module_list:
|
|
||||||
for m in module.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
init.kaiming_normal_(m.weight, **kwargs)
|
|
||||||
m.weight.data *= scale
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.fill_(bias_fill)
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
init.kaiming_normal_(m.weight, **kwargs)
|
|
||||||
m.weight.data *= scale
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.fill_(bias_fill)
|
|
||||||
elif isinstance(m, _BatchNorm):
|
|
||||||
init.constant_(m.weight, 1)
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.fill_(bias_fill)
|
|
||||||
|
|
||||||
|
|
||||||
class NormStyleCode(nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
"""Normalize the style codes.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Style codes with shape (b, c).
|
|
||||||
Returns:
|
|
||||||
Tensor: Normalized tensor.
|
|
||||||
"""
|
|
||||||
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
|
|
||||||
|
|
||||||
|
|
||||||
class ModulatedConv2d(nn.Module):
|
|
||||||
"""Modulated Conv2d used in StyleGAN2.
|
|
||||||
There is no bias in ModulatedConv2d.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Size of the convolving kernel.
|
|
||||||
num_style_feat (int): Channel number of style features.
|
|
||||||
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
|
|
||||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
|
||||||
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
eps=1e-8,
|
|
||||||
):
|
|
||||||
super(ModulatedConv2d, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.kernel_size = kernel_size
|
|
||||||
self.demodulate = demodulate
|
|
||||||
self.sample_mode = sample_mode
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
# modulation inside each modulated conv
|
|
||||||
self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
|
|
||||||
# initialization
|
|
||||||
default_init_weights(
|
|
||||||
self.modulation,
|
|
||||||
scale=1,
|
|
||||||
bias_fill=1,
|
|
||||||
a=0,
|
|
||||||
mode="fan_in",
|
|
||||||
nonlinearity="linear",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(
|
|
||||||
torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
|
|
||||||
/ math.sqrt(in_channels * kernel_size**2)
|
|
||||||
)
|
|
||||||
self.padding = kernel_size // 2
|
|
||||||
|
|
||||||
def forward(self, x, style):
|
|
||||||
"""Forward function.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Tensor with shape (b, c, h, w).
|
|
||||||
style (Tensor): Tensor with shape (b, num_style_feat).
|
|
||||||
Returns:
|
|
||||||
Tensor: Modulated tensor after convolution.
|
|
||||||
"""
|
|
||||||
b, c, h, w = x.shape # c = c_in
|
|
||||||
# weight modulation
|
|
||||||
style = self.modulation(style).view(b, 1, c, 1, 1)
|
|
||||||
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
|
|
||||||
weight = self.weight * style # (b, c_out, c_in, k, k)
|
|
||||||
|
|
||||||
if self.demodulate:
|
|
||||||
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
|
|
||||||
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
|
|
||||||
|
|
||||||
weight = weight.view(
|
|
||||||
b * self.out_channels, c, self.kernel_size, self.kernel_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# upsample or downsample if necessary
|
|
||||||
if self.sample_mode == "upsample":
|
|
||||||
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
|
|
||||||
elif self.sample_mode == "downsample":
|
|
||||||
x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
|
|
||||||
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
x = x.view(1, b * c, h, w)
|
|
||||||
# weight: (b*c_out, c_in, k, k), groups=b
|
|
||||||
out = F.conv2d(x, weight, padding=self.padding, groups=b)
|
|
||||||
out = out.view(b, self.out_channels, *out.shape[2:4])
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
|
|
||||||
f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StyleConv(nn.Module):
|
|
||||||
"""Style conv used in StyleGAN2.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of the input.
|
|
||||||
out_channels (int): Channel number of the output.
|
|
||||||
kernel_size (int): Size of the convolving kernel.
|
|
||||||
num_style_feat (int): Channel number of style features.
|
|
||||||
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
|
||||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
):
|
|
||||||
super(StyleConv, self).__init__()
|
|
||||||
self.modulated_conv = ModulatedConv2d(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size,
|
|
||||||
num_style_feat,
|
|
||||||
demodulate=demodulate,
|
|
||||||
sample_mode=sample_mode,
|
|
||||||
)
|
|
||||||
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
|
|
||||||
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
|
|
||||||
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x, style, noise=None):
|
|
||||||
# modulate
|
|
||||||
out = self.modulated_conv(x, style) * 2**0.5 # for conversion
|
|
||||||
# noise injection
|
|
||||||
if noise is None:
|
|
||||||
b, _, h, w = out.shape
|
|
||||||
noise = out.new_empty(b, 1, h, w).normal_()
|
|
||||||
out = out + self.weight * noise
|
|
||||||
# add bias
|
|
||||||
out = out + self.bias
|
|
||||||
# activation
|
|
||||||
out = self.activate(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ToRGB(nn.Module):
|
|
||||||
"""To RGB (image space) from features.
|
|
||||||
Args:
|
|
||||||
in_channels (int): Channel number of input.
|
|
||||||
num_style_feat (int): Channel number of style features.
|
|
||||||
upsample (bool): Whether to upsample. Default: True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, in_channels, num_style_feat, upsample=True):
|
|
||||||
super(ToRGB, self).__init__()
|
|
||||||
self.upsample = upsample
|
|
||||||
self.modulated_conv = ModulatedConv2d(
|
|
||||||
in_channels,
|
|
||||||
3,
|
|
||||||
kernel_size=1,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=False,
|
|
||||||
sample_mode=None,
|
|
||||||
)
|
|
||||||
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
|
||||||
|
|
||||||
def forward(self, x, style, skip=None):
|
|
||||||
"""Forward function.
|
|
||||||
Args:
|
|
||||||
x (Tensor): Feature tensor with shape (b, c, h, w).
|
|
||||||
style (Tensor): Tensor with shape (b, num_style_feat).
|
|
||||||
skip (Tensor): Base/skip tensor. Default: None.
|
|
||||||
Returns:
|
|
||||||
Tensor: RGB images.
|
|
||||||
"""
|
|
||||||
out = self.modulated_conv(x, style)
|
|
||||||
out = out + self.bias
|
|
||||||
if skip is not None:
|
|
||||||
if self.upsample:
|
|
||||||
skip = F.interpolate(
|
|
||||||
skip, scale_factor=2, mode="bilinear", align_corners=False
|
|
||||||
)
|
|
||||||
out = out + skip
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ConstantInput(nn.Module):
|
|
||||||
"""Constant input.
|
|
||||||
Args:
|
|
||||||
num_channel (int): Channel number of constant input.
|
|
||||||
size (int): Spatial size of constant input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_channel, size):
|
|
||||||
super(ConstantInput, self).__init__()
|
|
||||||
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
|
|
||||||
|
|
||||||
def forward(self, batch):
|
|
||||||
out = self.weight.repeat(batch, 1, 1, 1)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class StyleGAN2GeneratorClean(nn.Module):
|
|
||||||
"""Clean version of StyleGAN2 Generator.
|
|
||||||
Args:
|
|
||||||
out_size (int): The spatial size of outputs.
|
|
||||||
num_style_feat (int): Channel number of style features. Default: 512.
|
|
||||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
|
||||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
|
||||||
narrow (float): Narrow ratio for channels. Default: 1.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
|
|
||||||
):
|
|
||||||
super(StyleGAN2GeneratorClean, self).__init__()
|
|
||||||
# Style MLP layers
|
|
||||||
self.num_style_feat = num_style_feat
|
|
||||||
style_mlp_layers = [NormStyleCode()]
|
|
||||||
for i in range(num_mlp):
|
|
||||||
style_mlp_layers.extend(
|
|
||||||
[
|
|
||||||
nn.Linear(num_style_feat, num_style_feat, bias=True),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.style_mlp = nn.Sequential(*style_mlp_layers)
|
|
||||||
# initialization
|
|
||||||
default_init_weights(
|
|
||||||
self.style_mlp,
|
|
||||||
scale=1,
|
|
||||||
bias_fill=0,
|
|
||||||
a=0.2,
|
|
||||||
mode="fan_in",
|
|
||||||
nonlinearity="leaky_relu",
|
|
||||||
)
|
|
||||||
|
|
||||||
# channel list
|
|
||||||
channels = {
|
|
||||||
"4": int(512 * narrow),
|
|
||||||
"8": int(512 * narrow),
|
|
||||||
"16": int(512 * narrow),
|
|
||||||
"32": int(512 * narrow),
|
|
||||||
"64": int(256 * channel_multiplier * narrow),
|
|
||||||
"128": int(128 * channel_multiplier * narrow),
|
|
||||||
"256": int(64 * channel_multiplier * narrow),
|
|
||||||
"512": int(32 * channel_multiplier * narrow),
|
|
||||||
"1024": int(16 * channel_multiplier * narrow),
|
|
||||||
}
|
|
||||||
self.channels = channels
|
|
||||||
|
|
||||||
self.constant_input = ConstantInput(channels["4"], size=4)
|
|
||||||
self.style_conv1 = StyleConv(
|
|
||||||
channels["4"],
|
|
||||||
channels["4"],
|
|
||||||
kernel_size=3,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
)
|
|
||||||
self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)
|
|
||||||
|
|
||||||
self.log_size = int(math.log(out_size, 2))
|
|
||||||
self.num_layers = (self.log_size - 2) * 2 + 1
|
|
||||||
self.num_latent = self.log_size * 2 - 2
|
|
||||||
|
|
||||||
self.style_convs = nn.ModuleList()
|
|
||||||
self.to_rgbs = nn.ModuleList()
|
|
||||||
self.noises = nn.Module()
|
|
||||||
|
|
||||||
in_channels = channels["4"]
|
|
||||||
# noise
|
|
||||||
for layer_idx in range(self.num_layers):
|
|
||||||
resolution = 2 ** ((layer_idx + 5) // 2)
|
|
||||||
shape = [1, 1, resolution, resolution]
|
|
||||||
self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
|
|
||||||
# style convs and to_rgbs
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
out_channels = channels[f"{2**i}"]
|
|
||||||
self.style_convs.append(
|
|
||||||
StyleConv(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode="upsample",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.style_convs.append(
|
|
||||||
StyleConv(
|
|
||||||
out_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
num_style_feat=num_style_feat,
|
|
||||||
demodulate=True,
|
|
||||||
sample_mode=None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
|
|
||||||
in_channels = out_channels
|
|
||||||
|
|
||||||
def make_noise(self):
|
|
||||||
"""Make noise for noise injection."""
|
|
||||||
device = self.constant_input.weight.device
|
|
||||||
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
|
||||||
|
|
||||||
for i in range(3, self.log_size + 1):
|
|
||||||
for _ in range(2):
|
|
||||||
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
|
||||||
|
|
||||||
return noises
|
|
||||||
|
|
||||||
def get_latent(self, x):
|
|
||||||
return self.style_mlp(x)
|
|
||||||
|
|
||||||
def mean_latent(self, num_latent):
|
|
||||||
latent_in = torch.randn(
|
|
||||||
num_latent, self.num_style_feat, device=self.constant_input.weight.device
|
|
||||||
)
|
|
||||||
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
|
||||||
return latent
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
styles,
|
|
||||||
input_is_latent=False,
|
|
||||||
noise=None,
|
|
||||||
randomize_noise=True,
|
|
||||||
truncation=1,
|
|
||||||
truncation_latent=None,
|
|
||||||
inject_index=None,
|
|
||||||
return_latents=False,
|
|
||||||
):
|
|
||||||
"""Forward function for StyleGAN2GeneratorClean.
|
|
||||||
Args:
|
|
||||||
styles (list[Tensor]): Sample codes of styles.
|
|
||||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
|
||||||
noise (Tensor | None): Input noise or None. Default: None.
|
|
||||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
|
||||||
truncation (float): The truncation ratio. Default: 1.
|
|
||||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
|
||||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
|
||||||
return_latents (bool): Whether to return style latents. Default: False.
|
|
||||||
"""
|
|
||||||
# style codes -> latents with Style MLP layer
|
|
||||||
if not input_is_latent:
|
|
||||||
styles = [self.style_mlp(s) for s in styles]
|
|
||||||
# noises
|
|
||||||
if noise is None:
|
|
||||||
if randomize_noise:
|
|
||||||
noise = [None] * self.num_layers # for each style conv layer
|
|
||||||
else: # use the stored noise
|
|
||||||
noise = [
|
|
||||||
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
|
|
||||||
]
|
|
||||||
# style truncation
|
|
||||||
if truncation < 1:
|
|
||||||
style_truncation = []
|
|
||||||
for style in styles:
|
|
||||||
style_truncation.append(
|
|
||||||
truncation_latent + truncation * (style - truncation_latent)
|
|
||||||
)
|
|
||||||
styles = style_truncation
|
|
||||||
# get style latents with injection
|
|
||||||
if len(styles) == 1:
|
|
||||||
inject_index = self.num_latent
|
|
||||||
|
|
||||||
if styles[0].ndim < 3:
|
|
||||||
# repeat latent code for all the layers
|
|
||||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
else: # used for encoder with different latent code for each layer
|
|
||||||
latent = styles[0]
|
|
||||||
elif len(styles) == 2: # mixing noises
|
|
||||||
if inject_index is None:
|
|
||||||
inject_index = random.randint(1, self.num_latent - 1)
|
|
||||||
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
|
||||||
latent2 = (
|
|
||||||
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
|
||||||
)
|
|
||||||
latent = torch.cat([latent1, latent2], 1)
|
|
||||||
|
|
||||||
# main generation
|
|
||||||
out = self.constant_input(latent.shape[0])
|
|
||||||
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
|
||||||
skip = self.to_rgb1(out, latent[:, 1])
|
|
||||||
|
|
||||||
i = 1
|
|
||||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
|
||||||
self.style_convs[::2],
|
|
||||||
self.style_convs[1::2],
|
|
||||||
noise[1::2],
|
|
||||||
noise[2::2],
|
|
||||||
self.to_rgbs,
|
|
||||||
):
|
|
||||||
out = conv1(out, latent[:, i], noise=noise1)
|
|
||||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
|
||||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
|
||||||
i += 2
|
|
||||||
|
|
||||||
image = skip
|
|
||||||
|
|
||||||
if return_latents:
|
|
||||||
return image, latent
|
|
||||||
else:
|
|
||||||
return image, None
|
|
||||||
@ -1,194 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# type: ignore
|
|
||||||
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.autograd import Function
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
upfirdn2d_ext = None
|
|
||||||
|
|
||||||
|
|
||||||
class UpFirDn2dBackward(Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(
|
|
||||||
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
|
|
||||||
):
|
|
||||||
up_x, up_y = up
|
|
||||||
down_x, down_y = down
|
|
||||||
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
|
|
||||||
|
|
||||||
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
|
|
||||||
|
|
||||||
grad_input = upfirdn2d_ext.upfirdn2d(
|
|
||||||
grad_output,
|
|
||||||
grad_kernel,
|
|
||||||
down_x,
|
|
||||||
down_y,
|
|
||||||
up_x,
|
|
||||||
up_y,
|
|
||||||
g_pad_x0,
|
|
||||||
g_pad_x1,
|
|
||||||
g_pad_y0,
|
|
||||||
g_pad_y1,
|
|
||||||
)
|
|
||||||
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
|
|
||||||
|
|
||||||
ctx.save_for_backward(kernel)
|
|
||||||
|
|
||||||
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
|
||||||
|
|
||||||
ctx.up_x = up_x
|
|
||||||
ctx.up_y = up_y
|
|
||||||
ctx.down_x = down_x
|
|
||||||
ctx.down_y = down_y
|
|
||||||
ctx.pad_x0 = pad_x0
|
|
||||||
ctx.pad_x1 = pad_x1
|
|
||||||
ctx.pad_y0 = pad_y0
|
|
||||||
ctx.pad_y1 = pad_y1
|
|
||||||
ctx.in_size = in_size
|
|
||||||
ctx.out_size = out_size
|
|
||||||
|
|
||||||
return grad_input
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, gradgrad_input):
|
|
||||||
(kernel,) = ctx.saved_tensors
|
|
||||||
|
|
||||||
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
|
|
||||||
|
|
||||||
gradgrad_out = upfirdn2d_ext.upfirdn2d(
|
|
||||||
gradgrad_input,
|
|
||||||
kernel,
|
|
||||||
ctx.up_x,
|
|
||||||
ctx.up_y,
|
|
||||||
ctx.down_x,
|
|
||||||
ctx.down_y,
|
|
||||||
ctx.pad_x0,
|
|
||||||
ctx.pad_x1,
|
|
||||||
ctx.pad_y0,
|
|
||||||
ctx.pad_y1,
|
|
||||||
)
|
|
||||||
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
|
|
||||||
# ctx.out_size[1], ctx.in_size[3])
|
|
||||||
gradgrad_out = gradgrad_out.view(
|
|
||||||
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
|
|
||||||
)
|
|
||||||
|
|
||||||
return gradgrad_out, None, None, None, None, None, None, None, None
|
|
||||||
|
|
||||||
|
|
||||||
class UpFirDn2d(Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input, kernel, up, down, pad):
|
|
||||||
up_x, up_y = up
|
|
||||||
down_x, down_y = down
|
|
||||||
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
|
||||||
|
|
||||||
kernel_h, kernel_w = kernel.shape
|
|
||||||
_, channel, in_h, in_w = input.shape
|
|
||||||
ctx.in_size = input.shape
|
|
||||||
|
|
||||||
input = input.reshape(-1, in_h, in_w, 1)
|
|
||||||
|
|
||||||
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
|
|
||||||
|
|
||||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
|
||||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
|
||||||
ctx.out_size = (out_h, out_w)
|
|
||||||
|
|
||||||
ctx.up = (up_x, up_y)
|
|
||||||
ctx.down = (down_x, down_y)
|
|
||||||
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
|
|
||||||
|
|
||||||
g_pad_x0 = kernel_w - pad_x0 - 1
|
|
||||||
g_pad_y0 = kernel_h - pad_y0 - 1
|
|
||||||
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
|
|
||||||
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
|
|
||||||
|
|
||||||
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
|
|
||||||
|
|
||||||
out = upfirdn2d_ext.upfirdn2d(
|
|
||||||
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
|
||||||
)
|
|
||||||
# out = out.view(major, out_h, out_w, minor)
|
|
||||||
out = out.view(-1, channel, out_h, out_w)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
kernel, grad_kernel = ctx.saved_tensors
|
|
||||||
|
|
||||||
grad_input = UpFirDn2dBackward.apply(
|
|
||||||
grad_output,
|
|
||||||
kernel,
|
|
||||||
grad_kernel,
|
|
||||||
ctx.up,
|
|
||||||
ctx.down,
|
|
||||||
ctx.pad,
|
|
||||||
ctx.g_pad,
|
|
||||||
ctx.in_size,
|
|
||||||
ctx.out_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
return grad_input, None, None, None, None
|
|
||||||
|
|
||||||
|
|
||||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
|
||||||
if input.device.type == "cpu":
|
|
||||||
out = upfirdn2d_native(
|
|
||||||
input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
out = UpFirDn2d.apply(
|
|
||||||
input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
|
|
||||||
)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def upfirdn2d_native(
|
|
||||||
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
|
||||||
):
|
|
||||||
_, channel, in_h, in_w = input.shape
|
|
||||||
input = input.reshape(-1, in_h, in_w, 1)
|
|
||||||
|
|
||||||
_, in_h, in_w, minor = input.shape
|
|
||||||
kernel_h, kernel_w = kernel.shape
|
|
||||||
|
|
||||||
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
|
||||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
|
||||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
|
||||||
|
|
||||||
out = F.pad(
|
|
||||||
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
|
||||||
)
|
|
||||||
out = out[
|
|
||||||
:,
|
|
||||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
|
||||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
|
||||||
:,
|
|
||||||
]
|
|
||||||
|
|
||||||
out = out.permute(0, 3, 1, 2)
|
|
||||||
out = out.reshape(
|
|
||||||
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
|
||||||
)
|
|
||||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
|
||||||
out = F.conv2d(out, w)
|
|
||||||
out = out.reshape(
|
|
||||||
-1,
|
|
||||||
minor,
|
|
||||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
|
||||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
|
||||||
)
|
|
||||||
out = out.permute(0, 2, 3, 1)
|
|
||||||
out = out[:, ::down_y, ::down_x, :]
|
|
||||||
|
|
||||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
|
||||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
|
||||||
|
|
||||||
return out.view(-1, channel, out_h, out_w)
|
|
||||||
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright 2019 Ross Wightman
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@ -1,223 +0,0 @@
|
|||||||
""" DropBlock, DropPath
|
|
||||||
|
|
||||||
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
|
|
||||||
|
|
||||||
Papers:
|
|
||||||
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
|
|
||||||
|
|
||||||
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
|
|
||||||
|
|
||||||
Code:
|
|
||||||
DropBlock impl inspired by two Tensorflow impl that I liked:
|
|
||||||
- https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
|
|
||||||
- https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
|
|
||||||
|
|
||||||
Hacked together by / Copyright 2020 Ross Wightman
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def drop_block_2d(
|
|
||||||
x,
|
|
||||||
drop_prob: float = 0.1,
|
|
||||||
block_size: int = 7,
|
|
||||||
gamma_scale: float = 1.0,
|
|
||||||
with_noise: bool = False,
|
|
||||||
inplace: bool = False,
|
|
||||||
batchwise: bool = False,
|
|
||||||
):
|
|
||||||
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
|
||||||
|
|
||||||
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
|
|
||||||
runs with success, but needs further validation and possibly optimization for lower runtime impact.
|
|
||||||
"""
|
|
||||||
_, C, H, W = x.shape
|
|
||||||
total_size = W * H
|
|
||||||
clipped_block_size = min(block_size, min(W, H))
|
|
||||||
# seed_drop_rate, the gamma parameter
|
|
||||||
gamma = (
|
|
||||||
gamma_scale
|
|
||||||
* drop_prob
|
|
||||||
* total_size
|
|
||||||
/ clipped_block_size**2
|
|
||||||
/ ((W - block_size + 1) * (H - block_size + 1))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Forces the block to be inside the feature map.
|
|
||||||
w_i, h_i = torch.meshgrid(
|
|
||||||
torch.arange(W).to(x.device), torch.arange(H).to(x.device)
|
|
||||||
)
|
|
||||||
valid_block = (
|
|
||||||
(w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)
|
|
||||||
) & ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
|
|
||||||
valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
|
|
||||||
|
|
||||||
if batchwise:
|
|
||||||
# one mask for whole batch, quite a bit faster
|
|
||||||
uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
|
|
||||||
else:
|
|
||||||
uniform_noise = torch.rand_like(x)
|
|
||||||
block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
|
|
||||||
block_mask = -F.max_pool2d(
|
|
||||||
-block_mask,
|
|
||||||
kernel_size=clipped_block_size, # block_size,
|
|
||||||
stride=1,
|
|
||||||
padding=clipped_block_size // 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if with_noise:
|
|
||||||
normal_noise = (
|
|
||||||
torch.randn((1, C, H, W), dtype=x.dtype, device=x.device)
|
|
||||||
if batchwise
|
|
||||||
else torch.randn_like(x)
|
|
||||||
)
|
|
||||||
if inplace:
|
|
||||||
x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
|
|
||||||
else:
|
|
||||||
x = x * block_mask + normal_noise * (1 - block_mask)
|
|
||||||
else:
|
|
||||||
normalize_scale = (
|
|
||||||
block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)
|
|
||||||
).to(x.dtype)
|
|
||||||
if inplace:
|
|
||||||
x.mul_(block_mask * normalize_scale)
|
|
||||||
else:
|
|
||||||
x = x * block_mask * normalize_scale
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def drop_block_fast_2d(
|
|
||||||
x: torch.Tensor,
|
|
||||||
drop_prob: float = 0.1,
|
|
||||||
block_size: int = 7,
|
|
||||||
gamma_scale: float = 1.0,
|
|
||||||
with_noise: bool = False,
|
|
||||||
inplace: bool = False,
|
|
||||||
):
|
|
||||||
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
|
||||||
|
|
||||||
DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
|
|
||||||
block mask at edges.
|
|
||||||
"""
|
|
||||||
_, _, H, W = x.shape
|
|
||||||
total_size = W * H
|
|
||||||
clipped_block_size = min(block_size, min(W, H))
|
|
||||||
gamma = (
|
|
||||||
gamma_scale
|
|
||||||
* drop_prob
|
|
||||||
* total_size
|
|
||||||
/ clipped_block_size**2
|
|
||||||
/ ((W - block_size + 1) * (H - block_size + 1))
|
|
||||||
)
|
|
||||||
|
|
||||||
block_mask = torch.empty_like(x).bernoulli_(gamma)
|
|
||||||
block_mask = F.max_pool2d(
|
|
||||||
block_mask.to(x.dtype),
|
|
||||||
kernel_size=clipped_block_size,
|
|
||||||
stride=1,
|
|
||||||
padding=clipped_block_size // 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if with_noise:
|
|
||||||
normal_noise = torch.empty_like(x).normal_()
|
|
||||||
if inplace:
|
|
||||||
x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
|
|
||||||
else:
|
|
||||||
x = x * (1.0 - block_mask) + normal_noise * block_mask
|
|
||||||
else:
|
|
||||||
block_mask = 1 - block_mask
|
|
||||||
normalize_scale = (
|
|
||||||
block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)
|
|
||||||
).to(dtype=x.dtype)
|
|
||||||
if inplace:
|
|
||||||
x.mul_(block_mask * normalize_scale)
|
|
||||||
else:
|
|
||||||
x = x * block_mask * normalize_scale
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DropBlock2d(nn.Module):
|
|
||||||
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
drop_prob: float = 0.1,
|
|
||||||
block_size: int = 7,
|
|
||||||
gamma_scale: float = 1.0,
|
|
||||||
with_noise: bool = False,
|
|
||||||
inplace: bool = False,
|
|
||||||
batchwise: bool = False,
|
|
||||||
fast: bool = True,
|
|
||||||
):
|
|
||||||
super(DropBlock2d, self).__init__()
|
|
||||||
self.drop_prob = drop_prob
|
|
||||||
self.gamma_scale = gamma_scale
|
|
||||||
self.block_size = block_size
|
|
||||||
self.with_noise = with_noise
|
|
||||||
self.inplace = inplace
|
|
||||||
self.batchwise = batchwise
|
|
||||||
self.fast = fast # FIXME finish comparisons of fast vs not
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if not self.training or not self.drop_prob:
|
|
||||||
return x
|
|
||||||
if self.fast:
|
|
||||||
return drop_block_fast_2d(
|
|
||||||
x,
|
|
||||||
self.drop_prob,
|
|
||||||
self.block_size,
|
|
||||||
self.gamma_scale,
|
|
||||||
self.with_noise,
|
|
||||||
self.inplace,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return drop_block_2d(
|
|
||||||
x,
|
|
||||||
self.drop_prob,
|
|
||||||
self.block_size,
|
|
||||||
self.gamma_scale,
|
|
||||||
self.with_noise,
|
|
||||||
self.inplace,
|
|
||||||
self.batchwise,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def drop_path(
|
|
||||||
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
|
||||||
):
|
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
|
||||||
|
|
||||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
|
||||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
|
||||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
|
||||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
|
||||||
'survival rate' as the argument.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if drop_prob == 0.0 or not training:
|
|
||||||
return x
|
|
||||||
keep_prob = 1 - drop_prob
|
|
||||||
shape = (x.shape[0],) + (1,) * (
|
|
||||||
x.ndim - 1
|
|
||||||
) # work with diff dim tensors, not just 2D ConvNets
|
|
||||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
|
||||||
if keep_prob > 0.0 and scale_by_keep:
|
|
||||||
random_tensor.div_(keep_prob)
|
|
||||||
return x * random_tensor
|
|
||||||
|
|
||||||
|
|
||||||
class DropPath(nn.Module):
|
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
|
||||||
|
|
||||||
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
|
||||||
super(DropPath, self).__init__()
|
|
||||||
self.drop_prob = drop_prob
|
|
||||||
self.scale_by_keep = scale_by_keep
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
|
||||||
@ -1,31 +0,0 @@
|
|||||||
""" Layer/Module Helpers
|
|
||||||
Hacked together by / Copyright 2020 Ross Wightman
|
|
||||||
"""
|
|
||||||
import collections.abc
|
|
||||||
from itertools import repeat
|
|
||||||
|
|
||||||
|
|
||||||
# From PyTorch internals
|
|
||||||
def _ntuple(n):
|
|
||||||
def parse(x):
|
|
||||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
|
||||||
return x
|
|
||||||
return tuple(repeat(x, n))
|
|
||||||
|
|
||||||
return parse
|
|
||||||
|
|
||||||
|
|
||||||
to_1tuple = _ntuple(1)
|
|
||||||
to_2tuple = _ntuple(2)
|
|
||||||
to_3tuple = _ntuple(3)
|
|
||||||
to_4tuple = _ntuple(4)
|
|
||||||
to_ntuple = _ntuple
|
|
||||||
|
|
||||||
|
|
||||||
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
|
|
||||||
min_value = min_value or divisor
|
|
||||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
|
||||||
# Make sure that round down does not go down by more than 10%.
|
|
||||||
if new_v < round_limit * v:
|
|
||||||
new_v += divisor
|
|
||||||
return new_v
|
|
||||||
@ -1,128 +0,0 @@
|
|||||||
import math
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
|
||||||
|
|
||||||
|
|
||||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
|
||||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
|
||||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
|
||||||
def norm_cdf(x):
|
|
||||||
# Computes standard normal cumulative distribution function
|
|
||||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
|
||||||
|
|
||||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
|
||||||
warnings.warn(
|
|
||||||
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
|
||||||
"The distribution of values may be incorrect.",
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# Values are generated by using a truncated uniform distribution and
|
|
||||||
# then using the inverse CDF for the normal distribution.
|
|
||||||
# Get upper and lower cdf values
|
|
||||||
l = norm_cdf((a - mean) / std)
|
|
||||||
u = norm_cdf((b - mean) / std)
|
|
||||||
|
|
||||||
# Uniformly fill tensor with values from [l, u], then translate to
|
|
||||||
# [2l-1, 2u-1].
|
|
||||||
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
|
||||||
|
|
||||||
# Use inverse cdf transform for normal distribution to get truncated
|
|
||||||
# standard normal
|
|
||||||
tensor.erfinv_()
|
|
||||||
|
|
||||||
# Transform to proper mean, std
|
|
||||||
tensor.mul_(std * math.sqrt(2.0))
|
|
||||||
tensor.add_(mean)
|
|
||||||
|
|
||||||
# Clamp to ensure it's in the proper range
|
|
||||||
tensor.clamp_(min=a, max=b)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def trunc_normal_(
|
|
||||||
tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0
|
|
||||||
) -> torch.Tensor:
|
|
||||||
r"""Fills the input Tensor with values drawn from a truncated
|
|
||||||
normal distribution. The values are effectively drawn from the
|
|
||||||
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
|
||||||
with values outside :math:`[a, b]` redrawn until they are within
|
|
||||||
the bounds. The method used for generating the random values works
|
|
||||||
best when :math:`a \leq \text{mean} \leq b`.
|
|
||||||
|
|
||||||
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
|
|
||||||
applied while sampling the normal with mean/std applied, therefore a, b args
|
|
||||||
should be adjusted to match the range of mean, std args.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor: an n-dimensional `torch.Tensor`
|
|
||||||
mean: the mean of the normal distribution
|
|
||||||
std: the standard deviation of the normal distribution
|
|
||||||
a: the minimum cutoff value
|
|
||||||
b: the maximum cutoff value
|
|
||||||
Examples:
|
|
||||||
>>> w = torch.empty(3, 5)
|
|
||||||
>>> nn.init.trunc_normal_(w)
|
|
||||||
"""
|
|
||||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
|
||||||
|
|
||||||
|
|
||||||
def trunc_normal_tf_(
|
|
||||||
tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0
|
|
||||||
) -> torch.Tensor:
|
|
||||||
r"""Fills the input Tensor with values drawn from a truncated
|
|
||||||
normal distribution. The values are effectively drawn from the
|
|
||||||
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
|
||||||
with values outside :math:`[a, b]` redrawn until they are within
|
|
||||||
the bounds. The method used for generating the random values works
|
|
||||||
best when :math:`a \leq \text{mean} \leq b`.
|
|
||||||
|
|
||||||
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
|
|
||||||
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
|
|
||||||
and the result is subsquently scaled and shifted by the mean and std args.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor: an n-dimensional `torch.Tensor`
|
|
||||||
mean: the mean of the normal distribution
|
|
||||||
std: the standard deviation of the normal distribution
|
|
||||||
a: the minimum cutoff value
|
|
||||||
b: the maximum cutoff value
|
|
||||||
Examples:
|
|
||||||
>>> w = torch.empty(3, 5)
|
|
||||||
>>> nn.init.trunc_normal_(w)
|
|
||||||
"""
|
|
||||||
_no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
|
|
||||||
with torch.no_grad():
|
|
||||||
tensor.mul_(std).add_(mean)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
|
||||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
|
||||||
if mode == "fan_in":
|
|
||||||
denom = fan_in
|
|
||||||
elif mode == "fan_out":
|
|
||||||
denom = fan_out
|
|
||||||
elif mode == "fan_avg":
|
|
||||||
denom = (fan_in + fan_out) / 2
|
|
||||||
|
|
||||||
variance = scale / denom # type: ignore
|
|
||||||
|
|
||||||
if distribution == "truncated_normal":
|
|
||||||
# constant is stddev of standard normal truncated to (-2, 2)
|
|
||||||
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
|
||||||
elif distribution == "normal":
|
|
||||||
tensor.normal_(std=math.sqrt(variance))
|
|
||||||
elif distribution == "uniform":
|
|
||||||
bound = math.sqrt(3 * variance)
|
|
||||||
# pylint: disable=invalid-unary-operand-type
|
|
||||||
tensor.uniform_(-bound, bound)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"invalid distribution {distribution}")
|
|
||||||
|
|
||||||
|
|
||||||
def lecun_normal_(tensor):
|
|
||||||
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
|
||||||
@ -1,99 +1,5 @@
|
|||||||
import logging as logger
|
from spandrel import ModelLoader
|
||||||
|
|
||||||
from .architecture.DAT import DAT
|
def load_state_dict(state_dict):
|
||||||
from .architecture.face.codeformer import CodeFormer
|
print("WARNING: comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
|
||||||
from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
|
return ModelLoader().load_from_state_dict(state_dict).eval()
|
||||||
from .architecture.face.restoreformer_arch import RestoreFormer
|
|
||||||
from .architecture.HAT import HAT
|
|
||||||
from .architecture.LaMa import LaMa
|
|
||||||
from .architecture.OmniSR.OmniSR import OmniSR
|
|
||||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
|
||||||
from .architecture.SCUNet import SCUNet
|
|
||||||
from .architecture.SPSR import SPSRNet as SPSR
|
|
||||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
|
||||||
from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
|
|
||||||
from .architecture.Swin2SR import Swin2SR
|
|
||||||
from .architecture.SwinIR import SwinIR
|
|
||||||
from .types import PyTorchModel
|
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedModel(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(state_dict) -> PyTorchModel:
|
|
||||||
logger.debug(f"Loading state dict into pytorch model arch")
|
|
||||||
|
|
||||||
state_dict_keys = list(state_dict.keys())
|
|
||||||
|
|
||||||
if "params_ema" in state_dict_keys:
|
|
||||||
state_dict = state_dict["params_ema"]
|
|
||||||
elif "params-ema" in state_dict_keys:
|
|
||||||
state_dict = state_dict["params-ema"]
|
|
||||||
elif "params" in state_dict_keys:
|
|
||||||
state_dict = state_dict["params"]
|
|
||||||
|
|
||||||
state_dict_keys = list(state_dict.keys())
|
|
||||||
# SRVGGNet Real-ESRGAN (v2)
|
|
||||||
if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
|
|
||||||
model = RealESRGANv2(state_dict)
|
|
||||||
# SPSR (ESRGAN with lots of extra layers)
|
|
||||||
elif "f_HR_conv1.0.weight" in state_dict:
|
|
||||||
model = SPSR(state_dict)
|
|
||||||
# Swift-SRGAN
|
|
||||||
elif (
|
|
||||||
"model" in state_dict_keys
|
|
||||||
and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
|
|
||||||
):
|
|
||||||
model = SwiftSRGAN(state_dict)
|
|
||||||
# SwinIR, Swin2SR, HAT
|
|
||||||
elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
|
|
||||||
if (
|
|
||||||
"layers.0.residual_group.blocks.0.conv_block.cab.0.weight"
|
|
||||||
in state_dict_keys
|
|
||||||
):
|
|
||||||
model = HAT(state_dict)
|
|
||||||
elif "patch_embed.proj.weight" in state_dict_keys:
|
|
||||||
model = Swin2SR(state_dict)
|
|
||||||
else:
|
|
||||||
model = SwinIR(state_dict)
|
|
||||||
# GFPGAN
|
|
||||||
elif (
|
|
||||||
"toRGB.0.weight" in state_dict_keys
|
|
||||||
and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
|
|
||||||
):
|
|
||||||
model = GFPGANv1Clean(state_dict)
|
|
||||||
# RestoreFormer
|
|
||||||
elif (
|
|
||||||
"encoder.conv_in.weight" in state_dict_keys
|
|
||||||
and "encoder.down.0.block.0.norm1.weight" in state_dict_keys
|
|
||||||
):
|
|
||||||
model = RestoreFormer(state_dict)
|
|
||||||
elif (
|
|
||||||
"encoder.blocks.0.weight" in state_dict_keys
|
|
||||||
and "quantize.embedding.weight" in state_dict_keys
|
|
||||||
):
|
|
||||||
model = CodeFormer(state_dict)
|
|
||||||
# LaMa
|
|
||||||
elif (
|
|
||||||
"model.model.1.bn_l.running_mean" in state_dict_keys
|
|
||||||
or "generator.model.1.bn_l.running_mean" in state_dict_keys
|
|
||||||
):
|
|
||||||
model = LaMa(state_dict)
|
|
||||||
# Omni-SR
|
|
||||||
elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
|
|
||||||
model = OmniSR(state_dict)
|
|
||||||
# SCUNet
|
|
||||||
elif "m_head.0.weight" in state_dict_keys and "m_tail.0.weight" in state_dict_keys:
|
|
||||||
model = SCUNet(state_dict)
|
|
||||||
# DAT
|
|
||||||
elif "layers.0.blocks.2.attn.attn_mask_0" in state_dict_keys:
|
|
||||||
model = DAT(state_dict)
|
|
||||||
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
model = ESRGAN(state_dict)
|
|
||||||
except:
|
|
||||||
# pylint: disable=raise-missing-from
|
|
||||||
raise UnsupportedModel
|
|
||||||
return model
|
|
||||||
|
|||||||
@ -1,69 +0,0 @@
|
|||||||
from typing import Union
|
|
||||||
|
|
||||||
from .architecture.DAT import DAT
|
|
||||||
from .architecture.face.codeformer import CodeFormer
|
|
||||||
from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
|
|
||||||
from .architecture.face.restoreformer_arch import RestoreFormer
|
|
||||||
from .architecture.HAT import HAT
|
|
||||||
from .architecture.LaMa import LaMa
|
|
||||||
from .architecture.OmniSR.OmniSR import OmniSR
|
|
||||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
|
||||||
from .architecture.SCUNet import SCUNet
|
|
||||||
from .architecture.SPSR import SPSRNet as SPSR
|
|
||||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
|
||||||
from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
|
|
||||||
from .architecture.Swin2SR import Swin2SR
|
|
||||||
from .architecture.SwinIR import SwinIR
|
|
||||||
|
|
||||||
PyTorchSRModels = (
|
|
||||||
RealESRGANv2,
|
|
||||||
SPSR,
|
|
||||||
SwiftSRGAN,
|
|
||||||
ESRGAN,
|
|
||||||
SwinIR,
|
|
||||||
Swin2SR,
|
|
||||||
HAT,
|
|
||||||
OmniSR,
|
|
||||||
SCUNet,
|
|
||||||
DAT,
|
|
||||||
)
|
|
||||||
PyTorchSRModel = Union[
|
|
||||||
RealESRGANv2,
|
|
||||||
SPSR,
|
|
||||||
SwiftSRGAN,
|
|
||||||
ESRGAN,
|
|
||||||
SwinIR,
|
|
||||||
Swin2SR,
|
|
||||||
HAT,
|
|
||||||
OmniSR,
|
|
||||||
SCUNet,
|
|
||||||
DAT,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def is_pytorch_sr_model(model: object):
|
|
||||||
return isinstance(model, PyTorchSRModels)
|
|
||||||
|
|
||||||
|
|
||||||
PyTorchFaceModels = (GFPGANv1Clean, RestoreFormer, CodeFormer)
|
|
||||||
PyTorchFaceModel = Union[GFPGANv1Clean, RestoreFormer, CodeFormer]
|
|
||||||
|
|
||||||
|
|
||||||
def is_pytorch_face_model(model: object):
|
|
||||||
return isinstance(model, PyTorchFaceModels)
|
|
||||||
|
|
||||||
|
|
||||||
PyTorchInpaintModels = (LaMa,)
|
|
||||||
PyTorchInpaintModel = Union[LaMa]
|
|
||||||
|
|
||||||
|
|
||||||
def is_pytorch_inpaint_model(model: object):
|
|
||||||
return isinstance(model, PyTorchInpaintModels)
|
|
||||||
|
|
||||||
|
|
||||||
PyTorchModels = (*PyTorchSRModels, *PyTorchFaceModels, *PyTorchInpaintModels)
|
|
||||||
PyTorchModel = Union[PyTorchSRModel, PyTorchFaceModel, PyTorchInpaintModel]
|
|
||||||
|
|
||||||
|
|
||||||
def is_pytorch_model(model: object):
|
|
||||||
return isinstance(model, PyTorchModels)
|
|
||||||
61
comfy_extras/nodes_advanced_samplers.py
Normal file
61
comfy_extras/nodes_advanced_samplers.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import comfy.samplers
|
||||||
|
import comfy.utils
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm.auto import trange, tqdm
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable=None, total_upscale=2.0, upscale_method="bislerp", upscale_steps=None):
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
|
||||||
|
if upscale_steps is None:
|
||||||
|
upscale_steps = max(len(sigmas) // 2 + 1, 2)
|
||||||
|
else:
|
||||||
|
upscale_steps += 1
|
||||||
|
upscale_steps = min(upscale_steps, len(sigmas) + 1)
|
||||||
|
|
||||||
|
upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:]
|
||||||
|
|
||||||
|
orig_shape = x.size()
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
|
||||||
|
x = denoised
|
||||||
|
if i < len(upscales):
|
||||||
|
x = comfy.utils.common_upscale(x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled")
|
||||||
|
|
||||||
|
if sigmas[i + 1] > 0:
|
||||||
|
x += sigmas[i + 1] * torch.randn_like(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerLCMUpscale:
|
||||||
|
upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
|
||||||
|
"scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
|
||||||
|
"upscale_method": (s.upscale_methods,),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SAMPLER",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/samplers"
|
||||||
|
|
||||||
|
FUNCTION = "get_sampler"
|
||||||
|
|
||||||
|
def get_sampler(self, scale_ratio, scale_steps, upscale_method):
|
||||||
|
if scale_steps < 0:
|
||||||
|
scale_steps = None
|
||||||
|
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
||||||
|
return (sampler, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"SamplerLCMUpscale": SamplerLCMUpscale,
|
||||||
|
}
|
||||||
@ -25,6 +25,7 @@ class AlignYourStepsScheduler:
|
|||||||
return {"required":
|
return {"required":
|
||||||
{"model_type": (["SD1", "SDXL", "SVD"], ),
|
{"model_type": (["SD1", "SDXL", "SVD"], ),
|
||||||
"steps": ("INT", {"default": 10, "min": 10, "max": 10000}),
|
"steps": ("INT", {"default": 10, "min": 10, "max": 10000}),
|
||||||
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("SIGMAS",)
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
@ -32,11 +33,18 @@ class AlignYourStepsScheduler:
|
|||||||
|
|
||||||
FUNCTION = "get_sigmas"
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
def get_sigmas(self, model_type, steps):
|
def get_sigmas(self, model_type, steps, denoise):
|
||||||
|
total_steps = steps
|
||||||
|
if denoise < 1.0:
|
||||||
|
if denoise <= 0.0:
|
||||||
|
return (torch.FloatTensor([]),)
|
||||||
|
total_steps = round(steps * denoise)
|
||||||
|
|
||||||
sigmas = NOISE_LEVELS[model_type][:]
|
sigmas = NOISE_LEVELS[model_type][:]
|
||||||
if (steps + 1) != len(sigmas):
|
if (steps + 1) != len(sigmas):
|
||||||
sigmas = loglinear_interp(sigmas, steps + 1)
|
sigmas = loglinear_interp(sigmas, steps + 1)
|
||||||
|
|
||||||
|
sigmas = sigmas[-(total_steps + 1):]
|
||||||
sigmas[-1] = 0
|
sigmas[-1] = 0
|
||||||
return (torch.FloatTensor(sigmas), )
|
return (torch.FloatTensor(sigmas), )
|
||||||
|
|
||||||
|
|||||||
120
comfy_extras/nodes_attention_multiply.py
Normal file
120
comfy_extras/nodes_attention_multiply.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
|
||||||
|
def attention_multiply(attn, model, q, k, v, out):
|
||||||
|
m = model.clone()
|
||||||
|
sd = model.model_state_dict()
|
||||||
|
|
||||||
|
for key in sd:
|
||||||
|
if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)):
|
||||||
|
m.add_patches({key: (None,)}, 0.0, q)
|
||||||
|
if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)):
|
||||||
|
m.add_patches({key: (None,)}, 0.0, k)
|
||||||
|
if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)):
|
||||||
|
m.add_patches({key: (None,)}, 0.0, v)
|
||||||
|
if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)):
|
||||||
|
m.add_patches({key: (None,)}, 0.0, out)
|
||||||
|
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
class UNetSelfAttentionMultiply:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/attention_experiments"
|
||||||
|
|
||||||
|
def patch(self, model, q, k, v, out):
|
||||||
|
m = attention_multiply("attn1", model, q, k, v, out)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
class UNetCrossAttentionMultiply:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/attention_experiments"
|
||||||
|
|
||||||
|
def patch(self, model, q, k, v, out):
|
||||||
|
m = attention_multiply("attn2", model, q, k, v, out)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
class CLIPAttentionMultiply:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "clip": ("CLIP",),
|
||||||
|
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("CLIP",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/attention_experiments"
|
||||||
|
|
||||||
|
def patch(self, clip, q, k, v, out):
|
||||||
|
m = clip.clone()
|
||||||
|
sd = m.patcher.model_state_dict()
|
||||||
|
|
||||||
|
for key in sd:
|
||||||
|
if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"):
|
||||||
|
m.add_patches({key: (None,)}, 0.0, q)
|
||||||
|
if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"):
|
||||||
|
m.add_patches({key: (None,)}, 0.0, k)
|
||||||
|
if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"):
|
||||||
|
m.add_patches({key: (None,)}, 0.0, v)
|
||||||
|
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
||||||
|
m.add_patches({key: (None,)}, 0.0, out)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
class UNetTemporalAttentionMultiply:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/attention_experiments"
|
||||||
|
|
||||||
|
def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal):
|
||||||
|
m = model.clone()
|
||||||
|
sd = model.model_state_dict()
|
||||||
|
|
||||||
|
for k in sd:
|
||||||
|
if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")):
|
||||||
|
if '.time_stack.' in k:
|
||||||
|
m.add_patches({k: (None,)}, 0.0, self_temporal)
|
||||||
|
else:
|
||||||
|
m.add_patches({k: (None,)}, 0.0, self_structural)
|
||||||
|
elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")):
|
||||||
|
if '.time_stack.' in k:
|
||||||
|
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
||||||
|
else:
|
||||||
|
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
|
||||||
|
"UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
|
||||||
|
"CLIPAttentionMultiply": CLIPAttentionMultiply,
|
||||||
|
"UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
|
||||||
|
}
|
||||||
128
comfy_extras/nodes_audio.py
Normal file
128
comfy_extras/nodes_audio.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import torchaudio
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
import folder_paths
|
||||||
|
import os
|
||||||
|
|
||||||
|
class EmptyLatentAudio:
|
||||||
|
def __init__(self):
|
||||||
|
self.device = comfy.model_management.intermediate_device()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "generate"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/audio"
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
batch_size = 1
|
||||||
|
latent = torch.zeros([batch_size, 64, 1024], device=self.device)
|
||||||
|
return ({"samples":latent, "type": "audio"}, )
|
||||||
|
|
||||||
|
class VAEEncodeAudio:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/audio"
|
||||||
|
|
||||||
|
def encode(self, vae, audio):
|
||||||
|
t = vae.encode(audio["waveform"].movedim(1, -1))
|
||||||
|
return ({"samples":t}, )
|
||||||
|
|
||||||
|
class VAEDecodeAudio:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||||
|
RETURN_TYPES = ("AUDIO",)
|
||||||
|
FUNCTION = "decode"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/audio"
|
||||||
|
|
||||||
|
def decode(self, vae, samples):
|
||||||
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
|
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||||
|
|
||||||
|
class SaveAudio:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
self.type = "output"
|
||||||
|
self.prefix_append = ""
|
||||||
|
self.compress_level = 4
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "audio": ("AUDIO", ),
|
||||||
|
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save_audio"
|
||||||
|
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/audio"
|
||||||
|
|
||||||
|
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
|
filename_prefix += self.prefix_append
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
|
results = list()
|
||||||
|
for (batch_number, waveform) in enumerate(audio["waveform"]):
|
||||||
|
#TODO: metadata
|
||||||
|
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||||
|
file = f"{filename_with_batch_num}_{counter:05}_.flac"
|
||||||
|
torchaudio.save(os.path.join(full_output_folder, file), waveform, audio["sample_rate"], format="FLAC")
|
||||||
|
results.append({
|
||||||
|
"filename": file,
|
||||||
|
"subfolder": subfolder,
|
||||||
|
"type": self.type
|
||||||
|
})
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
return { "ui": { "audio": results } }
|
||||||
|
|
||||||
|
class LoadAudio:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||||
|
return {"required": {"audio": [sorted(files), ]}, }
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing/audio"
|
||||||
|
|
||||||
|
RETURN_TYPES = ("AUDIO", )
|
||||||
|
FUNCTION = "load"
|
||||||
|
|
||||||
|
def load(self, audio):
|
||||||
|
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||||
|
waveform, sample_rate = torchaudio.load(audio_path)
|
||||||
|
multiplier = 1.0
|
||||||
|
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||||
|
return (audio, )
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(s, audio):
|
||||||
|
image_path = folder_paths.get_annotated_filepath(audio)
|
||||||
|
m = hashlib.sha256()
|
||||||
|
with open(image_path, 'rb') as f:
|
||||||
|
m.update(f.read())
|
||||||
|
return m.digest().hex()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(s, audio):
|
||||||
|
if not folder_paths.exists_annotated_filepath(audio):
|
||||||
|
return "Invalid audio file: {}".format(audio)
|
||||||
|
return True
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"EmptyLatentAudio": EmptyLatentAudio,
|
||||||
|
"VAEEncodeAudio": VAEEncodeAudio,
|
||||||
|
"VAEDecodeAudio": VAEDecodeAudio,
|
||||||
|
"SaveAudio": SaveAudio,
|
||||||
|
"LoadAudio": LoadAudio,
|
||||||
|
}
|
||||||
@ -1,10 +1,5 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import comfy.model_management
|
|
||||||
|
|
||||||
from kornia.filters import canny
|
from kornia.filters import canny
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
class Canny:
|
class Canny:
|
||||||
|
|||||||
@ -28,6 +28,14 @@ class PorterDuffMode(Enum):
|
|||||||
|
|
||||||
|
|
||||||
def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
|
def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
|
||||||
|
# convert mask to alpha
|
||||||
|
src_alpha = 1 - src_alpha
|
||||||
|
dst_alpha = 1 - dst_alpha
|
||||||
|
# premultiply alpha
|
||||||
|
src_image = src_image * src_alpha
|
||||||
|
dst_image = dst_image * dst_alpha
|
||||||
|
|
||||||
|
# composite ops below assume alpha-premultiplied images
|
||||||
if mode == PorterDuffMode.ADD:
|
if mode == PorterDuffMode.ADD:
|
||||||
out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
|
out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
|
||||||
out_image = torch.clamp(src_image + dst_image, 0, 1)
|
out_image = torch.clamp(src_image + dst_image, 0, 1)
|
||||||
@ -35,7 +43,7 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
|
|||||||
out_alpha = torch.zeros_like(dst_alpha)
|
out_alpha = torch.zeros_like(dst_alpha)
|
||||||
out_image = torch.zeros_like(dst_image)
|
out_image = torch.zeros_like(dst_image)
|
||||||
elif mode == PorterDuffMode.DARKEN:
|
elif mode == PorterDuffMode.DARKEN:
|
||||||
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
||||||
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
|
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
|
||||||
elif mode == PorterDuffMode.DST:
|
elif mode == PorterDuffMode.DST:
|
||||||
out_alpha = dst_alpha
|
out_alpha = dst_alpha
|
||||||
@ -84,8 +92,13 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
|
|||||||
out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
|
out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
|
||||||
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
|
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
|
||||||
else:
|
else:
|
||||||
out_alpha = None
|
return None, None
|
||||||
out_image = None
|
|
||||||
|
# back to non-premultiplied alpha
|
||||||
|
out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image))
|
||||||
|
out_image = torch.clamp(out_image, 0, 1)
|
||||||
|
# convert alpha to mask
|
||||||
|
out_alpha = 1 - out_alpha
|
||||||
return out_image, out_alpha
|
return out_image, out_alpha
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -39,8 +39,8 @@ class KarrasScheduler:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {"required":
|
||||||
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
|
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
|
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
"rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
"rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -58,8 +58,8 @@ class ExponentialScheduler:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {"required":
|
||||||
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
|
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
|
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("SIGMAS",)
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
@ -76,8 +76,8 @@ class PolyexponentialScheduler:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {"required":
|
||||||
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
|
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
|
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
"rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
"rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -107,8 +107,7 @@ class SDTurboScheduler:
|
|||||||
def get_sigmas(self, model, steps, denoise):
|
def get_sigmas(self, model, steps, denoise):
|
||||||
start_step = 10 - int(10 * denoise)
|
start_step = 10 - int(10 * denoise)
|
||||||
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
|
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
|
||||||
comfy.model_management.load_models_gpu([model])
|
sigmas = model.get_model_object("model_sampling").sigma(timesteps)
|
||||||
sigmas = model.model.model_sampling.sigma(timesteps)
|
|
||||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
@ -117,8 +116,8 @@ class VPScheduler:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {"required":
|
||||||
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
"beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), #TODO: fix default values
|
"beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values
|
||||||
"beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
|
"beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
|
||||||
"eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}),
|
"eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -140,6 +139,7 @@ class SplitSigmas:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("SIGMAS","SIGMAS")
|
RETURN_TYPES = ("SIGMAS","SIGMAS")
|
||||||
|
RETURN_NAMES = ("high_sigmas", "low_sigmas")
|
||||||
CATEGORY = "sampling/custom_sampling/sigmas"
|
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||||
|
|
||||||
FUNCTION = "get_sigmas"
|
FUNCTION = "get_sigmas"
|
||||||
@ -149,6 +149,27 @@ class SplitSigmas:
|
|||||||
sigmas2 = sigmas[step:]
|
sigmas2 = sigmas[step:]
|
||||||
return (sigmas1, sigmas2)
|
return (sigmas1, sigmas2)
|
||||||
|
|
||||||
|
class SplitSigmasDenoise:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"sigmas": ("SIGMAS", ),
|
||||||
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SIGMAS","SIGMAS")
|
||||||
|
RETURN_NAMES = ("high_sigmas", "low_sigmas")
|
||||||
|
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||||
|
|
||||||
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
|
def get_sigmas(self, sigmas, denoise):
|
||||||
|
steps = max(sigmas.shape[-1] - 1, 0)
|
||||||
|
total_steps = round(steps * denoise)
|
||||||
|
sigmas1 = sigmas[:-(total_steps)]
|
||||||
|
sigmas2 = sigmas[-(total_steps + 1):]
|
||||||
|
return (sigmas1, sigmas2)
|
||||||
|
|
||||||
class FlipSigmas:
|
class FlipSigmas:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -359,6 +380,10 @@ class SamplerCustom:
|
|||||||
def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
|
def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
|
||||||
latent = latent_image
|
latent = latent_image
|
||||||
latent_image = latent["samples"]
|
latent_image = latent["samples"]
|
||||||
|
latent = latent.copy()
|
||||||
|
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
|
||||||
|
latent["samples"] = latent_image
|
||||||
|
|
||||||
if not add_noise:
|
if not add_noise:
|
||||||
noise = Noise_EmptyNoise().generate_noise(latent)
|
noise = Noise_EmptyNoise().generate_noise(latent)
|
||||||
else:
|
else:
|
||||||
@ -517,6 +542,9 @@ class SamplerCustomAdvanced:
|
|||||||
def sample(self, noise, guider, sampler, sigmas, latent_image):
|
def sample(self, noise, guider, sampler, sigmas, latent_image):
|
||||||
latent = latent_image
|
latent = latent_image
|
||||||
latent_image = latent["samples"]
|
latent_image = latent["samples"]
|
||||||
|
latent = latent.copy()
|
||||||
|
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
|
||||||
|
latent["samples"] = latent_image
|
||||||
|
|
||||||
noise_mask = None
|
noise_mask = None
|
||||||
if "noise_mask" in latent:
|
if "noise_mask" in latent:
|
||||||
@ -600,6 +628,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
||||||
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
"SamplerDPMAdaptative": SamplerDPMAdaptative,
|
||||||
"SplitSigmas": SplitSigmas,
|
"SplitSigmas": SplitSigmas,
|
||||||
|
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||||
"FlipSigmas": FlipSigmas,
|
"FlipSigmas": FlipSigmas,
|
||||||
|
|
||||||
"CFGGuider": CFGGuider,
|
"CFGGuider": CFGGuider,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user