liyin_code/models/stylegan2/op/fused_act.py

41 lines
1.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import torch
from torch import nn
from torch.nn import functional as F
# Directory containing this source file. It is not referenced by any code in
# this module -- presumably a leftover from a version that built a custom CUDA
# extension from sources next to this file; verify callers before removing.
module_path = os.path.dirname(__file__)
class FusedLeakyReLU(nn.Module):
    """Leaky ReLU module with a learnable bias and a constant output gain.

    The bias is a vector of length ``channel`` initialised to zeros; the
    actual computation is delegated to :func:`fused_leaky_relu`.

    Args:
        channel: length of the learnable bias vector.
        negative_slope: slope of the leaky ReLU for negative inputs.
        scale: constant gain multiplied onto the activation output.
    """

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        zero_bias = torch.zeros(channel)
        self.bias = nn.Parameter(zero_bias)
        self.negative_slope, self.scale = negative_slope, scale

    def forward(self, input):
        """Add the learnable bias to ``input``, apply leaky ReLU, scale the result."""
        return fused_leaky_relu(
            input,
            self.bias,
            negative_slope=self.negative_slope,
            scale=self.scale,
        )
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    """Add a broadcast bias, apply leaky ReLU, and multiply by a constant gain.

    This is a plain PyTorch implementation; no custom kernel is involved.

    Args:
        input: activation tensor. When ``input.ndim == 3`` the bias is aligned
            with the last dimension; for any other rank it is aligned with
            dimension 1 (e.g. ``(N, C)`` or ``(N, C, H, W)``).
        bias: 1-D bias whose length matches the aligned dimension, or ``None``
            to skip the bias add.
        negative_slope: slope of the leaky ReLU for negative inputs.
        scale: constant gain multiplied onto the activation output. The
            default sqrt(2) presumably follows the He/Kaiming convention of
            restoring the signal magnitude the activation suppresses for
            negative inputs -- confirm against the model's weight init.

    Returns:
        The activated tensor, on ``bias.device`` when a bias is given.
    """
    if bias is None:
        return F.leaky_relu(input, negative_slope=negative_slope) * scale

    # The previous version unconditionally called ``input.cuda()``, which
    # crashed on CPU-only machines and mismatched a CPU bias. Moving the input
    # to the bias's device keeps the old "CPU input + CUDA model" behaviour
    # and also works when everything lives on the CPU (``.to`` is a no-op when
    # the tensor is already there).
    input = input.to(bias.device)

    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    if input.ndim == 3:
        # Bias broadcasts along the last dimension: shape (1, 1, C).
        bias_shape = (1, *rest_dim, bias.shape[0])
    else:
        # Bias broadcasts along dimension 1: shape (1, C, 1, ...).
        bias_shape = (1, bias.shape[0], *rest_dim)

    activated = F.leaky_relu(
        input + bias.view(bias_shape), negative_slope=negative_slope
    )
    return activated * scale