mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
optimize(infer.synthesizer): all modules inherit from one
This commit is contained in:
@@ -33,8 +33,9 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
upsample_kernel_sizes: List[int],
|
||||
spk_embed_dim: int,
|
||||
gin_channels: int,
|
||||
sr: str | int,
|
||||
sr: Optional[str | int],
|
||||
encoder_dim: int,
|
||||
use_f0: bool,
|
||||
):
|
||||
super(SynthesizerTrnMs256NSFsid, self).__init__()
|
||||
if isinstance(sr, str):
|
||||
@@ -59,8 +60,8 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
self.segment_size = segment_size
|
||||
self.gin_channels = gin_channels
|
||||
# self.hop_length = hop_length#
|
||||
self.spk_embed_dim = spk_embed_dim
|
||||
|
||||
self.enc_p = TextEncoder(
|
||||
encoder_dim,
|
||||
inter_channels,
|
||||
@@ -70,18 +71,31 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
n_layers,
|
||||
kernel_size,
|
||||
float(p_dropout),
|
||||
f0=use_f0,
|
||||
)
|
||||
self.dec = NSFGenerator(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
sr=sr,
|
||||
)
|
||||
if use_f0:
|
||||
self.dec = NSFGenerator(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
sr=sr,
|
||||
)
|
||||
else:
|
||||
self.dec = Generator(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.enc_q = PosteriorEncoder(
|
||||
spec_channels,
|
||||
inter_channels,
|
||||
@@ -133,11 +147,11 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
pitchf: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
y_lengths: torch.Tensor,
|
||||
ds: Optional[torch.Tensor] = None,
|
||||
pitch: Optional[torch.Tensor] = None,
|
||||
pitchf: Optional[torch.Tensor] = None,
|
||||
): # 这里ds是id,[bs,1]
|
||||
# print(1,pitch.shape)#[bs,t]
|
||||
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
|
||||
@@ -147,10 +161,11 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
z_slice, ids_slice = rand_slice_segments_on_last_dim(
|
||||
z, y_lengths, self.segment_size
|
||||
)
|
||||
# print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
|
||||
pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
|
||||
# print(-2,pitchf.shape,z_slice.shape)
|
||||
o = self.dec(z_slice, pitchf, g=g)
|
||||
if pitchf is not None:
|
||||
pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
|
||||
o = self.dec(z_slice, pitchf, g=g)
|
||||
else:
|
||||
o = self.dec(z_slice, g=g)
|
||||
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
@torch.jit.export
|
||||
@@ -158,17 +173,15 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
pitch: torch.Tensor,
|
||||
nsff0: torch.Tensor,
|
||||
sid: torch.Tensor,
|
||||
pitch: Optional[torch.Tensor] = None,
|
||||
nsff0: Optional[torch.Tensor] = None,
|
||||
skip_head: Optional[torch.Tensor] = None,
|
||||
return_length: Optional[torch.Tensor] = None,
|
||||
# return_length2: Optional[torch.Tensor] = None,
|
||||
):
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
if skip_head is not None and return_length is not None:
|
||||
assert isinstance(skip_head, torch.Tensor)
|
||||
assert isinstance(return_length, torch.Tensor)
|
||||
head = int(skip_head.item())
|
||||
length = int(return_length.item())
|
||||
flow_head = torch.clamp(skip_head - 24, min=0)
|
||||
@@ -178,18 +191,28 @@ class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
z = z[:, :, dec_head : dec_head + length]
|
||||
x_mask = x_mask[:, :, dec_head : dec_head + length]
|
||||
nsff0 = nsff0[:, head : head + length]
|
||||
if nsff0 is not None:
|
||||
nsff0 = nsff0[:, head : head + length]
|
||||
else:
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
o = self.dec(
|
||||
z * x_mask,
|
||||
nsff0,
|
||||
g=g,
|
||||
# n_res=return_length2,
|
||||
)
|
||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||
del z_p, m_p, logs_p
|
||||
if nsff0 is not None:
|
||||
o = self.dec(
|
||||
z * x_mask,
|
||||
nsff0,
|
||||
g=g,
|
||||
# n_res=return_length2,
|
||||
)
|
||||
else:
|
||||
o = self.dec(
|
||||
z * x_mask,
|
||||
g=g,
|
||||
# n_res=return_length2
|
||||
)
|
||||
del x_mask, z
|
||||
return o # , x_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
|
||||
class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
|
||||
@@ -234,6 +257,7 @@ class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
|
||||
gin_channels,
|
||||
sr,
|
||||
256,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
@@ -279,10 +303,11 @@ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
|
||||
gin_channels,
|
||||
sr,
|
||||
768,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
||||
class SynthesizerTrnMs256NSFsid_nono(SynthesizerTrnMsNSFsid):
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels: int,
|
||||
@@ -304,162 +329,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
||||
gin_channels: int,
|
||||
sr=None,
|
||||
):
|
||||
super(SynthesizerTrnMs256NSFsid_nono, self).__init__()
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsample_initial_channel = upsample_initial_channel
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
self.segment_size = segment_size
|
||||
self.gin_channels = gin_channels
|
||||
# self.hop_length = hop_length#
|
||||
self.spk_embed_dim = spk_embed_dim
|
||||
self.enc_p = TextEncoder(
|
||||
256,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
float(p_dropout),
|
||||
f0=False,
|
||||
)
|
||||
self.dec = Generator(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.enc_q = PosteriorEncoder(
|
||||
spec_channels,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
5,
|
||||
1,
|
||||
16,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(
|
||||
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
|
||||
)
|
||||
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
self.dec.remove_weight_norm()
|
||||
self.flow.remove_weight_norm()
|
||||
if hasattr(self, "enc_q"):
|
||||
self.enc_q.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.dec._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.dec)
|
||||
for hook in self.flow._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flow)
|
||||
if hasattr(self, "enc_q"):
|
||||
for hook in self.enc_q._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc_q)
|
||||
return self
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
|
||||
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
z_slice, ids_slice = rand_slice_segments_on_last_dim(
|
||||
z, y_lengths, self.segment_size
|
||||
)
|
||||
o = self.dec(z_slice, g=g)
|
||||
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
@torch.jit.export
|
||||
def infer(
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
sid: torch.Tensor,
|
||||
skip_head: Optional[torch.Tensor] = None,
|
||||
return_length: Optional[torch.Tensor] = None,
|
||||
# return_length2: Optional[torch.Tensor] = None,
|
||||
):
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
if skip_head is not None and return_length is not None:
|
||||
assert isinstance(skip_head, torch.Tensor)
|
||||
assert isinstance(return_length, torch.Tensor)
|
||||
head = int(skip_head.item())
|
||||
length = int(return_length.item())
|
||||
flow_head = torch.clamp(skip_head - 24, min=0)
|
||||
dec_head = head - int(flow_head.item())
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths, flow_head)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
z = z[:, :, dec_head : dec_head + length]
|
||||
x_mask = x_mask[:, :, dec_head : dec_head + length]
|
||||
else:
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
o = self.dec(
|
||||
z * x_mask,
|
||||
g=g,
|
||||
# n_res=return_length2
|
||||
)
|
||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
|
||||
class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels: int,
|
||||
segment_size: int,
|
||||
inter_channels: int,
|
||||
hidden_channels: int,
|
||||
filter_channels: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
kernel_size: int,
|
||||
p_dropout: int,
|
||||
resblock: str,
|
||||
resblock_kernel_sizes: List[int],
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
upsample_rates: List[int],
|
||||
upsample_initial_channel: int,
|
||||
upsample_kernel_sizes: List[int],
|
||||
spk_embed_dim: int,
|
||||
gin_channels: int,
|
||||
sr=None,
|
||||
):
|
||||
super(SynthesizerTrnMs768NSFsid_nono, self).__init__(
|
||||
super().__init__(
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
@@ -477,16 +347,51 @@ class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
|
||||
upsample_kernel_sizes,
|
||||
spk_embed_dim,
|
||||
gin_channels,
|
||||
256,
|
||||
False,
|
||||
)
|
||||
del self.enc_p
|
||||
self.enc_p = TextEncoder(
|
||||
768,
|
||||
|
||||
|
||||
class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMsNSFsid):
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels: int,
|
||||
segment_size: int,
|
||||
inter_channels: int,
|
||||
hidden_channels: int,
|
||||
filter_channels: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
kernel_size: int,
|
||||
p_dropout: int,
|
||||
resblock: str,
|
||||
resblock_kernel_sizes: List[int],
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
upsample_rates: List[int],
|
||||
upsample_initial_channel: int,
|
||||
upsample_kernel_sizes: List[int],
|
||||
spk_embed_dim: int,
|
||||
gin_channels: int,
|
||||
sr=None,
|
||||
):
|
||||
super().__init__(
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
float(p_dropout),
|
||||
f0=False,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
spk_embed_dim,
|
||||
gin_channels,
|
||||
768,
|
||||
False,
|
||||
)
|
||||
|
||||
@@ -399,6 +399,7 @@ class RVC:
|
||||
p_len = input_wav.shape[0] // 160
|
||||
factor = pow(2, self.formant_shift / 12)
|
||||
return_length2 = int(np.ceil(return_length * factor))
|
||||
cache_pitch = cache_pitchf = None
|
||||
if self.if_f0 == 1:
|
||||
f0_extractor_frame = block_frame_16k + 800
|
||||
if f0method == "rmvpe":
|
||||
@@ -424,25 +425,18 @@ class RVC:
|
||||
p_len = torch.LongTensor([p_len]).to(self.device)
|
||||
sid = torch.LongTensor([0]).to(self.device)
|
||||
skip_head = torch.LongTensor([skip_head])
|
||||
return_length2 = torch.LongTensor([return_length2])
|
||||
# return_length2 = torch.LongTensor([return_length2])
|
||||
return_length = torch.LongTensor([return_length])
|
||||
with torch.no_grad():
|
||||
if self.if_f0 == 1:
|
||||
infered_audio, _, _ = self.net_g.infer(
|
||||
feats,
|
||||
p_len,
|
||||
cache_pitch,
|
||||
cache_pitchf,
|
||||
sid,
|
||||
skip_head,
|
||||
return_length,
|
||||
return_length2,
|
||||
)
|
||||
else:
|
||||
infered_audio, _, _ = self.net_g.infer(
|
||||
feats, p_len, sid, skip_head, return_length, return_length2
|
||||
)
|
||||
infered_audio = infered_audio.squeeze(1).float()
|
||||
infered_audio = self.net_g.infer(
|
||||
feats,
|
||||
p_len,
|
||||
sid,
|
||||
pitch=cache_pitch,
|
||||
pitchf=cache_pitchf,
|
||||
skip_head=skip_head,
|
||||
return_length=return_length,
|
||||
).squeeze(1).float()
|
||||
upp_res = int(np.floor(factor * self.tgt_sr // 100))
|
||||
if upp_res != self.tgt_sr // 100:
|
||||
if upp_res not in self.resample_kernel:
|
||||
|
||||
@@ -415,6 +415,7 @@ def train_and_evaluate(
|
||||
for batch_idx, info in data_iterator:
|
||||
# Data
|
||||
## Unpack
|
||||
pitch = pitchf = None
|
||||
if hps.if_f0 == 1:
|
||||
(
|
||||
phone,
|
||||
@@ -444,22 +445,13 @@ def train_and_evaluate(
|
||||
|
||||
# Calculate
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
if hps.if_f0 == 1:
|
||||
(
|
||||
y_hat,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
|
||||
else:
|
||||
(
|
||||
y_hat,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
|
||||
(
|
||||
y_hat,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
) = net_g(phone, phone_lengths, spec, spec_lengths, sid, pitch, pitchf)
|
||||
mel = spec_to_mel_torch(
|
||||
spec,
|
||||
hps.data.filter_length,
|
||||
|
||||
@@ -290,10 +290,15 @@ class Pipeline(object):
|
||||
feats = feats.to(feats0.dtype)
|
||||
p_len = torch.tensor([p_len], device=self.device).long()
|
||||
with torch.no_grad():
|
||||
hasp = pitch is not None and pitchf is not None
|
||||
arg = (feats, p_len, pitch, pitchf, sid) if hasp else (feats, p_len, sid)
|
||||
audio1 = (net_g.infer(*arg)[0][0, 0]).data.cpu().float().numpy()
|
||||
del arg
|
||||
audio1 = (
|
||||
net_g.infer(
|
||||
feats,
|
||||
p_len,
|
||||
sid,
|
||||
pitch=pitch,
|
||||
pitchf=pitchf,
|
||||
)[0, 0]
|
||||
).data.cpu().float().numpy()
|
||||
del feats, p_len, padding_mask
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -183,7 +183,7 @@ for idx, name in enumerate(
|
||||
pitchf = torch.FloatTensor(pitchf).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
audio = (
|
||||
net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0]
|
||||
net_g.infer(feats, p_len, sid, pitch=pitch, pitchf=pitchf)[0, 0]
|
||||
.data.cpu()
|
||||
.float()
|
||||
.numpy()
|
||||
|
||||
Reference in New Issue
Block a user