1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-09 12:30:38 +08:00

768VecOnnxExport (#328)

* Delete export_onnx.py

* Delete export_onnx_old.py

* Delete models_onnx_moess.py

* Support 768 Vec

* Add files via upload

* Support 768 Vec

Support 768 Vec

* Support 768 Vec Onnx Export

Support 768 Vec Onnx Export
This commit is contained in:
Ναρουσέ·μ·γιουμεμί·Χινακάννα
2023-05-21 19:11:29 +08:00
committed by GitHub
parent c3de24f2e0
commit 067731db9b
5 changed files with 144 additions and 1046 deletions

View File

@@ -61,7 +61,7 @@ class TextEncoder256(nn.Module):
return m, logs, x_mask
class TextEncoder256Sim(nn.Module):
class TextEncoder768(nn.Module):
def __init__(
self,
out_channels,
@@ -81,14 +81,14 @@ class TextEncoder256Sim(nn.Module):
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.emb_phone = nn.Linear(256, hidden_channels)
self.emb_phone = nn.Linear(768, hidden_channels)
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
if f0 == True:
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, phone, pitch, lengths):
if pitch == None:
@@ -102,8 +102,10 @@ class TextEncoder256Sim(nn.Module):
x.dtype
)
x = self.encoder(x * x_mask, x_mask)
x = self.proj(x) * x_mask
return x, x_mask
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
@@ -527,7 +529,7 @@ sr2sr = {
}
class SynthesizerTrnMs256NSFsidO(nn.Module):
class SynthesizerTrnMsNSFsidM(nn.Module):
def __init__(
self,
spec_channels,
@@ -571,15 +573,26 @@ class SynthesizerTrnMs256NSFsidO(nn.Module):
self.gin_channels = gin_channels
# self.hop_length = hop_length#
self.spk_embed_dim = spk_embed_dim
self.enc_p = TextEncoder256(
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
)
if self.gin_channels == 256:
self.enc_p = TextEncoder256(
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
)
else:
self.enc_p = TextEncoder768(
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
)
self.dec = GeneratorNSF(
inter_channels,
resblock,
@@ -605,6 +618,7 @@ class SynthesizerTrnMs256NSFsidO(nn.Module):
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
self.speaker_map = None
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
def remove_weight_norm(self):
@@ -612,10 +626,24 @@ class SynthesizerTrnMs256NSFsidO(nn.Module):
self.flow.remove_weight_norm()
self.enc_q.remove_weight_norm()
def forward(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
g = self.emb_g(sid).unsqueeze(-1)
def construct_spkmixmap(self, n_speaker):
    """Precompute the speaker-mixing table used by ``forward``.

    Looks up the embedding of every speaker id in ``range(n_speaker)`` via
    ``self.emb_g`` and stores the result in ``self.speaker_map`` with shape
    ``[1, n_speaker, 1, 1, gin_channels]``.

    The table is assembled in a local tensor and published to
    ``self.speaker_map`` only once every lookup has succeeded. If a lookup
    raises (for example when ``n_speaker`` exceeds the number of rows in
    ``self.emb_g``), ``self.speaker_map`` keeps its previous value instead of
    being left as a partially filled tensor of the wrong rank.

    Args:
        n_speaker: number of speaker ids (``0 .. n_speaker - 1``) to include.
    """
    table = torch.zeros((n_speaker, 1, 1, self.gin_channels))
    for speaker_id in range(n_speaker):
        # emb_g on a [1, 1] index yields [1, 1, gin_channels], which matches
        # the shape of one row of the table.
        table[speaker_id] = self.emb_g(torch.LongTensor([[speaker_id]]))
    # Leading singleton axis lets the table broadcast against mixing weights.
    self.speaker_map = table.unsqueeze(0)
def forward(self, phone, phone_lengths, pitch, nsff0, g, rnd, max_len=None):
    """Run inference: encode phones, apply the reversed flow, then decode.

    Args:
        phone: content features, passed to ``self.enc_p``.
        phone_lengths: lengths passed to ``self.enc_p``.
        pitch: pitch input passed to ``self.enc_p``.
        nsff0: f0 input passed to ``self.dec``.
        g: speaker conditioning. When ``self.speaker_map`` is set this is a
            2-D tensor of mixing weights (one weight per speaker in the map
            for each row); otherwise it holds speaker ids that are looked up
            in ``self.emb_g``.
        rnd: caller-supplied noise, broadcastable to the encoder mean. It is
            the only noise source, so this method never draws from the
            global RNG.
        max_len: optional cap on the time axis of the latent fed to the
            decoder (``None`` keeps everything).

    Returns:
        The output of ``self.dec``.
    """
    if self.speaker_map is not None:
        # Shapes below assume speaker_map is [1, S, 1, 1, H], as produced by
        # construct_spkmixmap, and g is [N, S].
        g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1))  # [N, S, 1, 1, 1]
        g = g * self.speaker_map  # [N, S, 1, 1, H]
        g = torch.sum(g, dim=1)  # [N, 1, 1, H]
        g = g.transpose(0, -1).transpose(0, -2).squeeze(0)  # [1, H, N]
    else:
        g = g.unsqueeze(0)
        g = self.emb_g(g).transpose(1, 2)

    m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
    # Sample the prior with the externally provided noise only. A previous
    # internal randn_like draw was immediately overwritten and has been
    # removed.
    z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
    z = self.flow(z_p, x_mask, g=g, reverse=True)
    o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
    return o
@@ -651,6 +679,36 @@ class MultiPeriodDiscriminator(torch.nn.Module):
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class MultiPeriodDiscriminatorV2(torch.nn.Module):
    """Discriminator ensemble: one ``DiscriminatorS`` plus one
    ``DiscriminatorP`` for each period in ``[2, 3, 5, 7, 11, 17, 23, 37]``.
    """

    def __init__(self, use_spectral_norm=False):
        """Build the sub-discriminators.

        Args:
            use_spectral_norm: forwarded unchanged to every sub-discriminator.
        """
        super().__init__()
        periods = [2, 3, 5, 7, 11, 17, 23, 37]
        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        for period in periods:
            members.append(
                DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
            )
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Apply every sub-discriminator to ``y`` and then to ``y_hat``.

        Returns:
            Four lists, each with one entry per sub-discriminator in order:
            outputs for ``y``, outputs for ``y_hat``, feature maps for ``y``,
            feature maps for ``y_hat``.
        """
        real_outs, gen_outs = [], []
        real_fmaps, gen_fmaps = [], []
        for disc in self.discriminators:
            real_out, real_fmap = disc(y)
            gen_out, gen_fmap = disc(y_hat)
            real_outs.append(real_out)
            gen_outs.append(gen_out)
            real_fmaps.append(real_fmap)
            gen_fmaps.append(gen_fmap)
        return real_outs, gen_outs, real_fmaps, gen_fmaps
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()