import torch
from torch import nn
from torch.nn import functional as F


class LayerNorm(nn.Module):
    """Layer normalization applied over dimension 1 of the input.

    The input's dimension 1 is swapped to the last position, normalized with
    ``F.layer_norm`` over a trailing axis of size ``channels``, and then
    swapped back. The output therefore has the same shape as the input.

    Args:
        channels: Size of dimension 1 of the input; also the length of the
            learnable ``gamma`` (scale) and ``beta`` (shift) vectors.
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # Learnable affine parameters: scale starts at 1, shift starts at 0.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over dimension 1 and return a tensor of the same shape.

        ``x`` must have at least two dimensions, and ``x.shape[1]`` must equal
        ``self.channels``.
        """
        # F.layer_norm normalizes trailing dims, so bring dim 1 to the end first.
        channels_last = x.transpose(1, -1)
        normed = F.layer_norm(
            channels_last,
            normalized_shape=(self.channels,),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        # Undo the swap so the output layout matches the input.
        return normed.transpose(1, -1)