mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
optimize(infer): move syns into rvc
This commit is contained in:
@@ -1,35 +1,20 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from rvc.synthesizers import SynthesizerTrnMsNSFsid
|
||||||
|
|
||||||
|
|
||||||
def get_synthesizer_ckpt(cpt, device=torch.device("cpu")):
|
def get_synthesizer_ckpt(cpt, device=torch.device("cpu")):
|
||||||
from infer.lib.infer_pack.models import (
|
|
||||||
SynthesizerTrnMs256NSFsid,
|
|
||||||
SynthesizerTrnMs256NSFsid_nono,
|
|
||||||
SynthesizerTrnMs768NSFsid,
|
|
||||||
SynthesizerTrnMs768NSFsid_nono,
|
|
||||||
)
|
|
||||||
|
|
||||||
# tgt_sr = cpt["config"][-1]
|
|
||||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
||||||
if_f0 = cpt.get("f0", 1)
|
if_f0 = cpt.get("f0", 1)
|
||||||
version = cpt.get("version", "v1")
|
version = cpt.get("version", "v1")
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
if if_f0 == 1:
|
encoder_dim = 256
|
||||||
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"])
|
|
||||||
else:
|
|
||||||
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
if if_f0 == 1:
|
encoder_dim = 768
|
||||||
net_g = SynthesizerTrnMs768NSFsid(*cpt["config"])
|
net_g = SynthesizerTrnMsNSFsid(
|
||||||
else:
|
*cpt["config"], encoder_dim=encoder_dim, use_f0 = if_f0==1,
|
||||||
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
)
|
||||||
del net_g.enc_q
|
del net_g.enc_q
|
||||||
# net_g.forward = net_g.infer
|
|
||||||
# ckpt = {}
|
|
||||||
# ckpt["config"] = cpt["config"]
|
|
||||||
# ckpt["f0"] = if_f0
|
|
||||||
# ckpt["version"] = version
|
|
||||||
# ckpt["info"] = cpt.get("info", "0epoch")
|
|
||||||
net_g.load_state_dict(cpt["weight"], strict=False)
|
net_g.load_state_dict(cpt["weight"], strict=False)
|
||||||
net_g = net_g.float()
|
net_g = net_g.float()
|
||||||
net_g.eval().to(device)
|
net_g.eval().to(device)
|
||||||
|
|||||||
@@ -59,12 +59,12 @@ from infer.lib.train.data_utils import (
|
|||||||
from rvc.discriminators import MultiPeriodDiscriminator
|
from rvc.discriminators import MultiPeriodDiscriminator
|
||||||
|
|
||||||
if hps.version == "v1":
|
if hps.version == "v1":
|
||||||
from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0
|
from rvc.synthesizers import SynthesizerTrnMs256NSFsid as RVC_Model_f0
|
||||||
from infer.lib.infer_pack.models import (
|
from rvc.synthesizers import (
|
||||||
SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
|
SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from infer.lib.infer_pack.models import (
|
from rvc.synthesizers import (
|
||||||
SynthesizerTrnMs768NSFsid as RVC_Model_f0,
|
SynthesizerTrnMs768NSFsid as RVC_Model_f0,
|
||||||
SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
|
SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from fairseq import checkpoint_utils
|
|||||||
|
|
||||||
# from models import SynthesizerTrn256#hifigan_nonsf
|
# from models import SynthesizerTrn256#hifigan_nonsf
|
||||||
# from lib.infer_pack.models import SynthesizerTrn256NSF as SynthesizerTrn256#hifigan_nsf
|
# from lib.infer_pack.models import SynthesizerTrn256NSF as SynthesizerTrn256#hifigan_nsf
|
||||||
from infer.lib.infer_pack.models import (
|
from rvc.synthesizers import (
|
||||||
SynthesizerTrnMs256NSFsid as SynthesizerTrn256,
|
SynthesizerTrnMs256NSFsid as SynthesizerTrn256,
|
||||||
) # hifigan_nsf
|
) # hifigan_nsf
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
|
|||||||
Reference in New Issue
Block a user