1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00
Files
Retrieval-based-Voice-Conve…/rvc/f0/e2e.py
github-actions[bot] 26d17cd714 chore(format): run black on dev (#43)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2024-06-12 21:07:24 +09:00

68 lines
1.7 KiB
Python

from typing import Tuple
import torch.nn as nn
from .deepunet import DeepUnet
class E2E(nn.Module):
    """DeepUnet feature extractor followed by a per-frame classification head.

    The pipeline is ``DeepUnet -> 3-channel 3x3 conv -> head``. The head is
    either a bidirectional GRU plus a linear layer (``n_gru > 0``) or a single
    linear layer (``n_gru == 0``); both end in ``Dropout(0.25)`` and
    ``Sigmoid``.

    Args:
        n_blocks: Forwarded to ``DeepUnet``.
        n_gru: Number of stacked GRU layers in the head; ``0`` selects the
            linear-only head.
        kernel_size: Forwarded to ``DeepUnet``.
        en_de_layers: Forwarded to ``DeepUnet``.
        inter_layers: Forwarded to ``DeepUnet``.
        in_channels: Forwarded to ``DeepUnet``.
        en_out_channels: Forwarded to ``DeepUnet`` and used as the input
            channel count of the 3x3 convolution.
    """

    # Feature width and output size the head is wired for. The GRU head
    # already hard-coded 3 * 128 inputs and 360 outputs; the linear-only head
    # previously read these from ``nn.N_MELS`` / ``nn.N_CLASS``, which do not
    # exist in ``torch.nn`` and raised AttributeError when ``n_gru == 0``.
    N_MELS = 128
    N_CLASS = 360

    def __init__(
        self,
        n_blocks: int,
        n_gru: int,
        kernel_size: Tuple[int, int],
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super().__init__()

        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))

        # The 3 conv channels are flattened together with the feature axis in
        # forward(), hence the factor of 3.
        head_in = 3 * self.N_MELS

        # NOTE: the layer order inside each Sequential (including Dropout
        # before Sigmoid) is kept exactly as it was so that the Sequential
        # indices, and therefore the state_dict keys, do not change.
        if n_gru:
            self.fc = nn.Sequential(
                self.BiGRU(head_in, 256, n_gru),
                # Bidirectional GRU emits 2 * 256 features per step.
                nn.Linear(512, self.N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(head_in, self.N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )

    def forward(self, mel):
        """Run the network on ``mel`` and return the head's sigmoid output.

        Args:
            mel: Input tensor. Its last two axes are swapped and a channel
                axis is inserted at position 1 before the U-Net; presumably
                ``(batch, n_mels, frames)`` -- TODO confirm against callers.

        Returns:
            The output of ``self.fc``; the last axis has ``N_CLASS`` entries.
        """
        mel = mel.transpose(-1, -2).unsqueeze(1)
        features = self.cnn(self.unet(mel))
        # Move the 3 conv channels next to the last axis and merge the two,
        # giving one flat feature vector per position along axis 1.
        x = features.transpose(1, 2).flatten(-2)
        return self.fc(x)

    class BiGRU(nn.Module):
        """Bidirectional, batch-first GRU that returns only the output sequence.

        Args:
            input_features: Size of each input step.
            hidden_features: Hidden size per direction; the output carries
                ``2 * hidden_features`` features per step.
            num_layers: Number of stacked GRU layers.
        """

        def __init__(
            self,
            input_features: int,
            hidden_features: int,
            num_layers: int,
        ):
            super().__init__()
            self.gru = nn.GRU(
                input_features,
                hidden_features,
                num_layers=num_layers,
                batch_first=True,
                bidirectional=True,
            )

        def forward(self, x):
            """Return the GRU output sequence, discarding the final hidden state."""
            output, _ = self.gru(x)
            return output