1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-08 03:55:47 +08:00

optimize(infer.synthesizer): all modules inherit from one

This commit is contained in:
源文雨
2024-06-10 21:34:35 +09:00
parent b67050b2f7
commit e33ef19200
5 changed files with 127 additions and 231 deletions

View File

@@ -33,8 +33,9 @@ class SynthesizerTrnMsNSFsid(nn.Module):
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr: str | int,
sr: Optional[str | int],
encoder_dim: int,
use_f0: bool,
):
super(SynthesizerTrnMs256NSFsid, self).__init__()
if isinstance(sr, str):
@@ -59,8 +60,8 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.gin_channels = gin_channels
# self.hop_length = hop_length#
self.spk_embed_dim = spk_embed_dim
self.enc_p = TextEncoder(
encoder_dim,
inter_channels,
@@ -70,18 +71,31 @@ class SynthesizerTrnMsNSFsid(nn.Module):
n_layers,
kernel_size,
float(p_dropout),
f0=use_f0,
)
self.dec = NSFGenerator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
sr=sr,
)
if use_f0:
self.dec = NSFGenerator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
sr=sr,
)
else:
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
@@ -133,11 +147,11 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
pitch: torch.Tensor,
pitchf: torch.Tensor,
y: torch.Tensor,
y_lengths: torch.Tensor,
ds: Optional[torch.Tensor] = None,
pitch: Optional[torch.Tensor] = None,
pitchf: Optional[torch.Tensor] = None,
): # 这里ds是id[bs,1]
# print(1,pitch.shape)#[bs,t]
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t广播的
@@ -147,10 +161,11 @@ class SynthesizerTrnMsNSFsid(nn.Module):
z_slice, ids_slice = rand_slice_segments_on_last_dim(
z, y_lengths, self.segment_size
)
# print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
# print(-2,pitchf.shape,z_slice.shape)
o = self.dec(z_slice, pitchf, g=g)
if pitchf is not None:
pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
o = self.dec(z_slice, pitchf, g=g)
else:
o = self.dec(z_slice, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
@torch.jit.export
@@ -158,17 +173,15 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
pitch: torch.Tensor,
nsff0: torch.Tensor,
sid: torch.Tensor,
pitch: Optional[torch.Tensor] = None,
nsff0: Optional[torch.Tensor] = None,
skip_head: Optional[torch.Tensor] = None,
return_length: Optional[torch.Tensor] = None,
# return_length2: Optional[torch.Tensor] = None,
):
g = self.emb_g(sid).unsqueeze(-1)
if skip_head is not None and return_length is not None:
assert isinstance(skip_head, torch.Tensor)
assert isinstance(return_length, torch.Tensor)
head = int(skip_head.item())
length = int(return_length.item())
flow_head = torch.clamp(skip_head - 24, min=0)
@@ -178,18 +191,28 @@ class SynthesizerTrnMsNSFsid(nn.Module):
z = self.flow(z_p, x_mask, g=g, reverse=True)
z = z[:, :, dec_head : dec_head + length]
x_mask = x_mask[:, :, dec_head : dec_head + length]
nsff0 = nsff0[:, head : head + length]
if nsff0 is not None:
nsff0 = nsff0[:, head : head + length]
else:
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(
z * x_mask,
nsff0,
g=g,
# n_res=return_length2,
)
return o, x_mask, (z, z_p, m_p, logs_p)
del z_p, m_p, logs_p
if nsff0 is not None:
o = self.dec(
z * x_mask,
nsff0,
g=g,
# n_res=return_length2,
)
else:
o = self.dec(
z * x_mask,
g=g,
# n_res=return_length2
)
del x_mask, z
return o # , x_mask, (z, z_p, m_p, logs_p)
class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
@@ -234,6 +257,7 @@ class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
gin_channels,
sr,
256,
True,
)
@@ -279,10 +303,11 @@ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
gin_channels,
sr,
768,
True,
)
class SynthesizerTrnMs256NSFsid_nono(nn.Module):
class SynthesizerTrnMs256NSFsid_nono(SynthesizerTrnMsNSFsid):
def __init__(
self,
spec_channels: int,
@@ -304,162 +329,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
gin_channels: int,
sr=None,
):
super(SynthesizerTrnMs256NSFsid_nono, self).__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = float(p_dropout)
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.gin_channels = gin_channels
# self.hop_length = hop_length#
self.spk_embed_dim = spk_embed_dim
self.enc_p = TextEncoder(
256,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
float(p_dropout),
f0=False,
)
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
def remove_weight_norm(self):
self.dec.remove_weight_norm()
self.flow.remove_weight_norm()
if hasattr(self, "enc_q"):
self.enc_q.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
return self
@torch.jit.ignore
def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id[bs,1]
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t广播的
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
z_slice, ids_slice = rand_slice_segments_on_last_dim(
z, y_lengths, self.segment_size
)
o = self.dec(z_slice, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
@torch.jit.export
def infer(
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
sid: torch.Tensor,
skip_head: Optional[torch.Tensor] = None,
return_length: Optional[torch.Tensor] = None,
# return_length2: Optional[torch.Tensor] = None,
):
g = self.emb_g(sid).unsqueeze(-1)
if skip_head is not None and return_length is not None:
assert isinstance(skip_head, torch.Tensor)
assert isinstance(return_length, torch.Tensor)
head = int(skip_head.item())
length = int(return_length.item())
flow_head = torch.clamp(skip_head - 24, min=0)
dec_head = head - int(flow_head.item())
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths, flow_head)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
z = self.flow(z_p, x_mask, g=g, reverse=True)
z = z[:, :, dec_head : dec_head + length]
x_mask = x_mask[:, :, dec_head : dec_head + length]
else:
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(
z * x_mask,
g=g,
# n_res=return_length2
)
return o, x_mask, (z, z_p, m_p, logs_p)
class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
def __init__(
self,
spec_channels: int,
segment_size: int,
inter_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr=None,
):
super(SynthesizerTrnMs768NSFsid_nono, self).__init__(
super().__init__(
spec_channels,
segment_size,
inter_channels,
@@ -477,16 +347,51 @@ class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
256,
False,
)
del self.enc_p
self.enc_p = TextEncoder(
768,
class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMsNSFsid):
def __init__(
self,
spec_channels: int,
segment_size: int,
inter_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr=None,
):
super().__init__(
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
float(p_dropout),
f0=False,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
768,
False,
)