mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
22 lines · 577 B · Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Scale(nn.Module):
    """A learnable scale parameter.

    This layer scales the input by a learnable factor. It multiplies a
    learnable 0-dim (scalar) parameter, i.e. a tensor of shape ``()``, with
    an input of any shape via broadcasting.

    Args:
        scale (float): Initial value of scale factor. Default: 1.0
    """

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        # torch.tensor() on a Python scalar yields a 0-dim tensor. Keep it
        # 0-dim: switching to shape (1,) would make load_state_dict reject
        # state_dicts that store a 0-dim ``scale`` (shape mismatch).
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply ``x`` element-wise by the learnable scale.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: ``x * self.scale``; the 0-dim scale broadcasts, so
            the result has the same shape as ``x``.
        """
        return x * self.scale
|