mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
optimize(infer): move syns into rvc
This commit is contained in:
@@ -1,397 +0,0 @@
|
||||
from typing import Optional, List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from rvc.residuals import ResidualCouplingBlock
|
||||
from rvc.utils import (
|
||||
slice_on_last_dim,
|
||||
rand_slice_segments_on_last_dim,
|
||||
)
|
||||
from rvc.encoders import TextEncoder, PosteriorEncoder
|
||||
from rvc.generators import Generator
|
||||
from rvc.nsf import NSFGenerator
|
||||
|
||||
|
||||
class SynthesizerTrnMsNSFsid(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels: int,
|
||||
segment_size: int,
|
||||
inter_channels: int,
|
||||
hidden_channels: int,
|
||||
filter_channels: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
kernel_size: int,
|
||||
p_dropout: int,
|
||||
resblock: str,
|
||||
resblock_kernel_sizes: List[int],
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
upsample_rates: List[int],
|
||||
upsample_initial_channel: int,
|
||||
upsample_kernel_sizes: List[int],
|
||||
spk_embed_dim: int,
|
||||
gin_channels: int,
|
||||
sr: Optional[Union[str, int]],
|
||||
encoder_dim: int,
|
||||
use_f0: bool,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(sr, str):
|
||||
sr = {
|
||||
"32k": 32000,
|
||||
"40k": 40000,
|
||||
"48k": 48000,
|
||||
}[sr]
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = float(p_dropout)
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsample_initial_channel = upsample_initial_channel
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
self.segment_size = segment_size
|
||||
self.gin_channels = gin_channels
|
||||
self.spk_embed_dim = spk_embed_dim
|
||||
|
||||
self.enc_p = TextEncoder(
|
||||
encoder_dim,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
float(p_dropout),
|
||||
f0=use_f0,
|
||||
)
|
||||
if use_f0:
|
||||
self.dec = NSFGenerator(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
sr=sr,
|
||||
)
|
||||
else:
|
||||
self.dec = Generator(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.enc_q = PosteriorEncoder(
|
||||
spec_channels,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
5,
|
||||
1,
|
||||
16,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(
|
||||
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
|
||||
)
|
||||
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
self.dec.remove_weight_norm()
|
||||
self.flow.remove_weight_norm()
|
||||
if hasattr(self, "enc_q"):
|
||||
self.enc_q.remove_weight_norm()
|
||||
|
||||
def __prepare_scriptable__(self):
|
||||
for hook in self.dec._forward_pre_hooks.values():
|
||||
# The hook we want to remove is an instance of WeightNorm class, so
|
||||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||||
# because of shadowing, so we check the module name directly.
|
||||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.dec)
|
||||
for hook in self.flow._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.flow)
|
||||
if hasattr(self, "enc_q"):
|
||||
for hook in self.enc_q._forward_pre_hooks.values():
|
||||
if (
|
||||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||||
and hook.__class__.__name__ == "WeightNorm"
|
||||
):
|
||||
torch.nn.utils.remove_weight_norm(self.enc_q)
|
||||
return self
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
y_lengths: torch.Tensor,
|
||||
ds: Optional[torch.Tensor] = None,
|
||||
pitch: Optional[torch.Tensor] = None,
|
||||
pitchf: Optional[torch.Tensor] = None,
|
||||
): # 这里ds是id,[bs,1]
|
||||
# print(1,pitch.shape)#[bs,t]
|
||||
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
z_slice, ids_slice = rand_slice_segments_on_last_dim(
|
||||
z, y_lengths, self.segment_size
|
||||
)
|
||||
if pitchf is not None:
|
||||
pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
|
||||
o = self.dec(z_slice, pitchf, g=g)
|
||||
else:
|
||||
o = self.dec(z_slice, g=g)
|
||||
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
@torch.jit.export
|
||||
def infer(
|
||||
self,
|
||||
phone: torch.Tensor,
|
||||
phone_lengths: torch.Tensor,
|
||||
sid: torch.Tensor,
|
||||
pitch: Optional[torch.Tensor] = None,
|
||||
pitchf: Optional[torch.Tensor] = None, # nsff0
|
||||
skip_head: Optional[torch.Tensor] = None,
|
||||
return_length: Optional[torch.Tensor] = None,
|
||||
# return_length2: Optional[torch.Tensor] = None,
|
||||
):
|
||||
g = self.emb_g(sid).unsqueeze(-1)
|
||||
if skip_head is not None and return_length is not None:
|
||||
head = int(skip_head.item())
|
||||
length = int(return_length.item())
|
||||
flow_head = torch.clamp(skip_head - 24, min=0)
|
||||
dec_head = head - int(flow_head.item())
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
z = z[:, :, dec_head : dec_head + length]
|
||||
x_mask = x_mask[:, :, dec_head : dec_head + length]
|
||||
if pitchf is not None:
|
||||
pitchf = pitchf[:, head : head + length]
|
||||
else:
|
||||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||
del z_p, m_p, logs_p
|
||||
if pitchf is not None:
|
||||
o = self.dec(
|
||||
z * x_mask,
|
||||
pitchf,
|
||||
g=g,
|
||||
# n_res=return_length2,
|
||||
)
|
||||
else:
|
||||
o = self.dec(
|
||||
z * x_mask,
|
||||
g=g,
|
||||
# n_res=return_length2
|
||||
)
|
||||
del x_mask, z
|
||||
return o # , x_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
|
||||
class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels: int,
|
||||
segment_size: int,
|
||||
inter_channels: int,
|
||||
hidden_channels: int,
|
||||
filter_channels: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
kernel_size: int,
|
||||
p_dropout: int,
|
||||
resblock: str,
|
||||
resblock_kernel_sizes: List[int],
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
upsample_rates: List[int],
|
||||
upsample_initial_channel: int,
|
||||
upsample_kernel_sizes: List[int],
|
||||
spk_embed_dim: int,
|
||||
gin_channels: int,
|
||||
sr: Union[str, int],
|
||||
):
|
||||
super().__init__(
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
spk_embed_dim,
|
||||
gin_channels,
|
||||
sr,
|
||||
256,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels: int,
|
||||
segment_size: int,
|
||||
inter_channels: int,
|
||||
hidden_channels: int,
|
||||
filter_channels: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
kernel_size: int,
|
||||
p_dropout: int,
|
||||
resblock: str,
|
||||
resblock_kernel_sizes: List[int],
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
upsample_rates: List[int],
|
||||
upsample_initial_channel: int,
|
||||
upsample_kernel_sizes: List[int],
|
||||
spk_embed_dim: int,
|
||||
gin_channels: int,
|
||||
sr: Union[str, int],
|
||||
):
|
||||
super().__init__(
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
spk_embed_dim,
|
||||
gin_channels,
|
||||
sr,
|
||||
768,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
class SynthesizerTrnMs256NSFsid_nono(SynthesizerTrnMsNSFsid):
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels: int,
|
||||
segment_size: int,
|
||||
inter_channels: int,
|
||||
hidden_channels: int,
|
||||
filter_channels: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
kernel_size: int,
|
||||
p_dropout: int,
|
||||
resblock: str,
|
||||
resblock_kernel_sizes: List[int],
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
upsample_rates: List[int],
|
||||
upsample_initial_channel: int,
|
||||
upsample_kernel_sizes: List[int],
|
||||
spk_embed_dim: int,
|
||||
gin_channels: int,
|
||||
sr=None,
|
||||
):
|
||||
super().__init__(
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
spk_embed_dim,
|
||||
gin_channels,
|
||||
256,
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMsNSFsid):
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels: int,
|
||||
segment_size: int,
|
||||
inter_channels: int,
|
||||
hidden_channels: int,
|
||||
filter_channels: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
kernel_size: int,
|
||||
p_dropout: int,
|
||||
resblock: str,
|
||||
resblock_kernel_sizes: List[int],
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
upsample_rates: List[int],
|
||||
upsample_initial_channel: int,
|
||||
upsample_kernel_sizes: List[int],
|
||||
spk_embed_dim: int,
|
||||
gin_channels: int,
|
||||
sr=None,
|
||||
):
|
||||
super().__init__(
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
spk_embed_dim,
|
||||
gin_channels,
|
||||
768,
|
||||
False,
|
||||
)
|
||||
@@ -1,35 +1,20 @@
|
||||
import torch
|
||||
|
||||
from rvc.synthesizers import SynthesizerTrnMsNSFsid
|
||||
|
||||
|
||||
def get_synthesizer_ckpt(cpt, device=torch.device("cpu")):
|
||||
from infer.lib.infer_pack.models import (
|
||||
SynthesizerTrnMs256NSFsid,
|
||||
SynthesizerTrnMs256NSFsid_nono,
|
||||
SynthesizerTrnMs768NSFsid,
|
||||
SynthesizerTrnMs768NSFsid_nono,
|
||||
)
|
||||
|
||||
# tgt_sr = cpt["config"][-1]
|
||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
||||
if_f0 = cpt.get("f0", 1)
|
||||
version = cpt.get("version", "v1")
|
||||
if version == "v1":
|
||||
if if_f0 == 1:
|
||||
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"])
|
||||
else:
|
||||
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
||||
encoder_dim = 256
|
||||
elif version == "v2":
|
||||
if if_f0 == 1:
|
||||
net_g = SynthesizerTrnMs768NSFsid(*cpt["config"])
|
||||
else:
|
||||
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
||||
encoder_dim = 768
|
||||
net_g = SynthesizerTrnMsNSFsid(
|
||||
*cpt["config"], encoder_dim=encoder_dim, use_f0 = if_f0==1,
|
||||
)
|
||||
del net_g.enc_q
|
||||
# net_g.forward = net_g.infer
|
||||
# ckpt = {}
|
||||
# ckpt["config"] = cpt["config"]
|
||||
# ckpt["f0"] = if_f0
|
||||
# ckpt["version"] = version
|
||||
# ckpt["info"] = cpt.get("info", "0epoch")
|
||||
net_g.load_state_dict(cpt["weight"], strict=False)
|
||||
net_g = net_g.float()
|
||||
net_g.eval().to(device)
|
||||
|
||||
Reference in New Issue
Block a user