1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

optimize(infer.synthesizer): all modules inherit from one

This commit is contained in:
源文雨
2024-06-10 21:34:35 +09:00
parent b67050b2f7
commit e33ef19200
5 changed files with 127 additions and 231 deletions

View File

@@ -33,8 +33,9 @@ class SynthesizerTrnMsNSFsid(nn.Module):
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr: str | int,
sr: Optional[str | int],
encoder_dim: int,
use_f0: bool,
):
super(SynthesizerTrnMs256NSFsid, self).__init__()
if isinstance(sr, str):
@@ -59,8 +60,8 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.gin_channels = gin_channels
# self.hop_length = hop_length#
self.spk_embed_dim = spk_embed_dim
self.enc_p = TextEncoder(
encoder_dim,
inter_channels,
@@ -70,18 +71,31 @@ class SynthesizerTrnMsNSFsid(nn.Module):
n_layers,
kernel_size,
float(p_dropout),
f0=use_f0,
)
self.dec = NSFGenerator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
sr=sr,
)
if use_f0:
self.dec = NSFGenerator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
sr=sr,
)
else:
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
@@ -133,11 +147,11 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
pitch: torch.Tensor,
pitchf: torch.Tensor,
y: torch.Tensor,
y_lengths: torch.Tensor,
ds: Optional[torch.Tensor] = None,
pitch: Optional[torch.Tensor] = None,
pitchf: Optional[torch.Tensor] = None,
): # 这里ds是id[bs,1]
# print(1,pitch.shape)#[bs,t]
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t广播的
@@ -147,10 +161,11 @@ class SynthesizerTrnMsNSFsid(nn.Module):
z_slice, ids_slice = rand_slice_segments_on_last_dim(
z, y_lengths, self.segment_size
)
# print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
# print(-2,pitchf.shape,z_slice.shape)
o = self.dec(z_slice, pitchf, g=g)
if pitchf is not None:
pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
o = self.dec(z_slice, pitchf, g=g)
else:
o = self.dec(z_slice, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
@torch.jit.export
@@ -158,17 +173,15 @@ class SynthesizerTrnMsNSFsid(nn.Module):
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
pitch: torch.Tensor,
nsff0: torch.Tensor,
sid: torch.Tensor,
pitch: Optional[torch.Tensor] = None,
nsff0: Optional[torch.Tensor] = None,
skip_head: Optional[torch.Tensor] = None,
return_length: Optional[torch.Tensor] = None,
# return_length2: Optional[torch.Tensor] = None,
):
g = self.emb_g(sid).unsqueeze(-1)
if skip_head is not None and return_length is not None:
assert isinstance(skip_head, torch.Tensor)
assert isinstance(return_length, torch.Tensor)
head = int(skip_head.item())
length = int(return_length.item())
flow_head = torch.clamp(skip_head - 24, min=0)
@@ -178,18 +191,28 @@ class SynthesizerTrnMsNSFsid(nn.Module):
z = self.flow(z_p, x_mask, g=g, reverse=True)
z = z[:, :, dec_head : dec_head + length]
x_mask = x_mask[:, :, dec_head : dec_head + length]
nsff0 = nsff0[:, head : head + length]
if nsff0 is not None:
nsff0 = nsff0[:, head : head + length]
else:
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(
z * x_mask,
nsff0,
g=g,
# n_res=return_length2,
)
return o, x_mask, (z, z_p, m_p, logs_p)
del z_p, m_p, logs_p
if nsff0 is not None:
o = self.dec(
z * x_mask,
nsff0,
g=g,
# n_res=return_length2,
)
else:
o = self.dec(
z * x_mask,
g=g,
# n_res=return_length2
)
del x_mask, z
return o # , x_mask, (z, z_p, m_p, logs_p)
class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
@@ -234,6 +257,7 @@ class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
gin_channels,
sr,
256,
True,
)
@@ -279,10 +303,11 @@ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
gin_channels,
sr,
768,
True,
)
class SynthesizerTrnMs256NSFsid_nono(nn.Module):
class SynthesizerTrnMs256NSFsid_nono(SynthesizerTrnMsNSFsid):
def __init__(
self,
spec_channels: int,
@@ -304,162 +329,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
gin_channels: int,
sr=None,
):
super(SynthesizerTrnMs256NSFsid_nono, self).__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = float(p_dropout)
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.gin_channels = gin_channels
# self.hop_length = hop_length#
self.spk_embed_dim = spk_embed_dim
self.enc_p = TextEncoder(
256,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
float(p_dropout),
f0=False,
)
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
def remove_weight_norm(self):
self.dec.remove_weight_norm()
self.flow.remove_weight_norm()
if hasattr(self, "enc_q"):
self.enc_q.remove_weight_norm()
def __prepare_scriptable__(self):
for hook in self.dec._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.dec)
for hook in self.flow._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.flow)
if hasattr(self, "enc_q"):
for hook in self.enc_q._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
torch.nn.utils.remove_weight_norm(self.enc_q)
return self
@torch.jit.ignore
def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id[bs,1]
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t广播的
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
z_slice, ids_slice = rand_slice_segments_on_last_dim(
z, y_lengths, self.segment_size
)
o = self.dec(z_slice, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
@torch.jit.export
def infer(
self,
phone: torch.Tensor,
phone_lengths: torch.Tensor,
sid: torch.Tensor,
skip_head: Optional[torch.Tensor] = None,
return_length: Optional[torch.Tensor] = None,
# return_length2: Optional[torch.Tensor] = None,
):
g = self.emb_g(sid).unsqueeze(-1)
if skip_head is not None and return_length is not None:
assert isinstance(skip_head, torch.Tensor)
assert isinstance(return_length, torch.Tensor)
head = int(skip_head.item())
length = int(return_length.item())
flow_head = torch.clamp(skip_head - 24, min=0)
dec_head = head - int(flow_head.item())
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths, flow_head)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
z = self.flow(z_p, x_mask, g=g, reverse=True)
z = z[:, :, dec_head : dec_head + length]
x_mask = x_mask[:, :, dec_head : dec_head + length]
else:
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
z = self.flow(z_p, x_mask, g=g, reverse=True)
o = self.dec(
z * x_mask,
g=g,
# n_res=return_length2
)
return o, x_mask, (z, z_p, m_p, logs_p)
class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
def __init__(
self,
spec_channels: int,
segment_size: int,
inter_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr=None,
):
super(SynthesizerTrnMs768NSFsid_nono, self).__init__(
super().__init__(
spec_channels,
segment_size,
inter_channels,
@@ -477,16 +347,51 @@ class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
256,
False,
)
del self.enc_p
self.enc_p = TextEncoder(
768,
class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMsNSFsid):
def __init__(
self,
spec_channels: int,
segment_size: int,
inter_channels: int,
hidden_channels: int,
filter_channels: int,
n_heads: int,
n_layers: int,
kernel_size: int,
p_dropout: int,
resblock: str,
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
upsample_rates: List[int],
upsample_initial_channel: int,
upsample_kernel_sizes: List[int],
spk_embed_dim: int,
gin_channels: int,
sr=None,
):
super().__init__(
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
float(p_dropout),
f0=False,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
spk_embed_dim,
gin_channels,
768,
False,
)

View File

@@ -399,6 +399,7 @@ class RVC:
p_len = input_wav.shape[0] // 160
factor = pow(2, self.formant_shift / 12)
return_length2 = int(np.ceil(return_length * factor))
cache_pitch = cache_pitchf = None
if self.if_f0 == 1:
f0_extractor_frame = block_frame_16k + 800
if f0method == "rmvpe":
@@ -424,25 +425,18 @@ class RVC:
p_len = torch.LongTensor([p_len]).to(self.device)
sid = torch.LongTensor([0]).to(self.device)
skip_head = torch.LongTensor([skip_head])
return_length2 = torch.LongTensor([return_length2])
# return_length2 = torch.LongTensor([return_length2])
return_length = torch.LongTensor([return_length])
with torch.no_grad():
if self.if_f0 == 1:
infered_audio, _, _ = self.net_g.infer(
feats,
p_len,
cache_pitch,
cache_pitchf,
sid,
skip_head,
return_length,
return_length2,
)
else:
infered_audio, _, _ = self.net_g.infer(
feats, p_len, sid, skip_head, return_length, return_length2
)
infered_audio = infered_audio.squeeze(1).float()
infered_audio = self.net_g.infer(
feats,
p_len,
sid,
pitch=cache_pitch,
pitchf=cache_pitchf,
skip_head=skip_head,
return_length=return_length,
).squeeze(1).float()
upp_res = int(np.floor(factor * self.tgt_sr // 100))
if upp_res != self.tgt_sr // 100:
if upp_res not in self.resample_kernel:

View File

@@ -415,6 +415,7 @@ def train_and_evaluate(
for batch_idx, info in data_iterator:
# Data
## Unpack
pitch = pitchf = None
if hps.if_f0 == 1:
(
phone,
@@ -444,22 +445,13 @@ def train_and_evaluate(
# Calculate
with autocast(enabled=hps.train.fp16_run):
if hps.if_f0 == 1:
(
y_hat,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
else:
(
y_hat,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
(
y_hat,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
) = net_g(phone, phone_lengths, spec, spec_lengths, sid, pitch, pitchf)
mel = spec_to_mel_torch(
spec,
hps.data.filter_length,

View File

@@ -290,10 +290,15 @@ class Pipeline(object):
feats = feats.to(feats0.dtype)
p_len = torch.tensor([p_len], device=self.device).long()
with torch.no_grad():
hasp = pitch is not None and pitchf is not None
arg = (feats, p_len, pitch, pitchf, sid) if hasp else (feats, p_len, sid)
audio1 = (net_g.infer(*arg)[0][0, 0]).data.cpu().float().numpy()
del arg
audio1 = (
net_g.infer(
feats,
p_len,
sid,
pitch=pitch,
pitchf=pitchf,
)[0, 0]
).data.cpu().float().numpy()
del feats, p_len, padding_mask
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@@ -183,7 +183,7 @@ for idx, name in enumerate(
pitchf = torch.FloatTensor(pitchf).unsqueeze(0).to(device)
with torch.no_grad():
audio = (
net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0]
net_g.infer(feats, p_len, sid, pitch=pitch, pitchf=pitchf)[0, 0]
.data.cpu()
.float()
.numpy()