mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
68 lines
1.7 KiB
Python
68 lines
1.7 KiB
Python
from typing import Tuple
|
|
|
|
import torch.nn as nn
|
|
|
|
from .deepunet import DeepUnet
|
|
|
|
|
|
class E2E(nn.Module):
    """DeepUnet backbone followed by a 3-channel conv projection and a
    sigmoid-activated per-frame head.

    The head is either a bidirectional GRU + Linear (when ``n_gru`` is truthy)
    or a single Linear layer. Both variants map ``3 * N_MELS`` input features
    to ``N_CLASS`` outputs.
    """

    # Size of the last feature axis the head consumes per conv channel, and
    # number of output classes. These are the sizes the GRU head has always
    # hard-coded (3 * 128 in, 360 out); the plain-linear head previously read
    # them from the non-existent ``nn.N_MELS`` / ``nn.N_CLASS``.
    N_MELS = 128
    N_CLASS = 360

    def __init__(
        self,
        n_blocks: int,
        n_gru: int,
        kernel_size: Tuple[int, int],
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        """Build the network.

        Args:
            n_blocks: Forwarded to ``DeepUnet``.
            n_gru: Number of stacked bidirectional GRU layers in the head.
                A falsy value (e.g. 0) selects the plain Linear head instead.
            kernel_size: Forwarded to ``DeepUnet``.
            en_de_layers: Forwarded to ``DeepUnet``.
            inter_layers: Forwarded to ``DeepUnet``.
            in_channels: Forwarded to ``DeepUnet``.
            en_out_channels: Forwarded to ``DeepUnet``; also the input channel
                count of the 3x3 projection conv, so it must match the number
                of channels ``DeepUnet`` emits.
        """
        super().__init__()

        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        # Project the U-Net features down to 3 channels; padding keeps the
        # two spatial dimensions unchanged.
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))

        head_in = 3 * self.N_MELS
        if n_gru:
            gru_hidden = 256
            self.fc = nn.Sequential(
                # BiGRU is the module-level class, not an attribute of self.
                BiGRU(head_in, gru_hidden, n_gru),
                # Bidirectional GRU concatenates both directions: 2 * hidden.
                nn.Linear(2 * gru_hidden, self.N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(head_in, self.N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )

    def forward(self, mel):
        """Run the network on a batch of spectrogram-like inputs.

        NOTE(review): ``mel`` is assumed to be (batch, N_MELS, frames), so that
        after the transpose/unsqueeze below it is (batch, 1, frames, N_MELS).
        Confirm this against callers.

        Returns:
            The head output after ``nn.Sigmoid``, so values lie in (0, 1).
            Under the shape assumption above, the result is
            (batch, frames, N_CLASS).
        """
        # Swap the last two axes and add a singleton channel axis for Conv2d.
        mel = mel.transpose(-1, -2).unsqueeze(1)
        # Move the 3 conv channels next to the feature axis, then merge them so
        # each time step carries 3 * N_MELS features for the head.
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        return self.fc(x)
|
|
|
|
class BiGRU(nn.Module):
    """Batch-first bidirectional GRU that yields only its output sequence.

    ``nn.GRU`` returns ``(output, h_n)``. This wrapper drops ``h_n`` so the
    layer can be dropped straight into an ``nn.Sequential``. Each time step of
    the output has ``2 * hidden_features`` features, because the forward and
    backward directions are concatenated.
    """

    def __init__(
        self,
        input_features: int,
        hidden_features: int,
        num_layers: int,
    ):
        super().__init__()
        gru_options = {
            "num_layers": num_layers,
            "batch_first": True,
            "bidirectional": True,
        }
        self.gru = nn.GRU(input_features, hidden_features, **gru_options)

    def forward(self, x):
        """Return the GRU output sequence, discarding the final hidden state."""
        sequence, _final_hidden = self.gru(x)
        return sequence
|