From 1a4cb9294e342c6e56b883167e217fcb440dda3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Mon, 10 Jun 2024 22:03:57 +0900 Subject: [PATCH] optimize(infer): move syns into rvc --- infer/lib/jit/synthesizer.py | 29 +++++-------------- infer/modules/train/train.py | 6 ++-- .../models.py => rvc/synthesizers.py | 0 tools/cmd/infer-pm-index256.py | 2 +- 4 files changed, 11 insertions(+), 26 deletions(-) rename infer/lib/infer_pack/models.py => rvc/synthesizers.py (100%) diff --git a/infer/lib/jit/synthesizer.py b/infer/lib/jit/synthesizer.py index a27180a..02e90bb 100644 --- a/infer/lib/jit/synthesizer.py +++ b/infer/lib/jit/synthesizer.py @@ -1,35 +1,20 @@ import torch +from rvc.synthesizers import SynthesizerTrnMsNSFsid + def get_synthesizer_ckpt(cpt, device=torch.device("cpu")): - from infer.lib.infer_pack.models import ( - SynthesizerTrnMs256NSFsid, - SynthesizerTrnMs256NSFsid_nono, - SynthesizerTrnMs768NSFsid, - SynthesizerTrnMs768NSFsid_nono, - ) - - # tgt_sr = cpt["config"][-1] cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] if_f0 = cpt.get("f0", 1) version = cpt.get("version", "v1") if version == "v1": - if if_f0 == 1: - net_g = SynthesizerTrnMs256NSFsid(*cpt["config"]) - else: - net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) + encoder_dim = 256 elif version == "v2": - if if_f0 == 1: - net_g = SynthesizerTrnMs768NSFsid(*cpt["config"]) - else: - net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) + encoder_dim = 768 + net_g = SynthesizerTrnMsNSFsid( + *cpt["config"], encoder_dim=encoder_dim, use_f0 = if_f0==1, + ) del net_g.enc_q - # net_g.forward = net_g.infer - # ckpt = {} - # ckpt["config"] = cpt["config"] - # ckpt["f0"] = if_f0 - # ckpt["version"] = version - # ckpt["info"] = cpt.get("info", "0epoch") net_g.load_state_dict(cpt["weight"], strict=False) net_g = net_g.float() net_g.eval().to(device) diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index 7881687..85abb25 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -59,12 +59,12 @@ from infer.lib.train.data_utils import ( from rvc.discriminators import MultiPeriodDiscriminator if hps.version == "v1": - from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0 - from infer.lib.infer_pack.models import ( + from rvc.synthesizers import SynthesizerTrnMs256NSFsid as RVC_Model_f0 + from rvc.synthesizers import ( SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0, ) else: - from infer.lib.infer_pack.models import ( + from rvc.synthesizers import ( SynthesizerTrnMs768NSFsid as RVC_Model_f0, SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0, ) diff --git a/infer/lib/infer_pack/models.py b/rvc/synthesizers.py similarity index 100% rename from infer/lib/infer_pack/models.py rename to rvc/synthesizers.py diff --git a/tools/cmd/infer-pm-index256.py b/tools/cmd/infer-pm-index256.py index 9a84c44..5a56cb5 100644 --- a/tools/cmd/infer-pm-index256.py +++ b/tools/cmd/infer-pm-index256.py @@ -24,7 +24,7 @@ from fairseq import checkpoint_utils # from models import SynthesizerTrn256#hifigan_nonsf # from lib.infer_pack.models import SynthesizerTrn256NSF as SynthesizerTrn256#hifigan_nsf -from infer.lib.infer_pack.models import ( +from rvc.synthesizers import ( SynthesizerTrnMs256NSFsid as SynthesizerTrn256, ) # hifigan_nsf from scipy.io import wavfile