diff --git a/infer/lib/jit/synthesizer.py b/infer/lib/jit/synthesizer.py index 5069fe9..b117989 100644 --- a/infer/lib/jit/synthesizer.py +++ b/infer/lib/jit/synthesizer.py @@ -1,5 +1,6 @@ import torch + def get_synthesizer_ckpt(cpt, device=torch.device("cpu")): from infer.lib.infer_pack.models import ( SynthesizerTrnMs256NSFsid, @@ -35,7 +36,9 @@ def get_synthesizer_ckpt(cpt, device=torch.device("cpu")): net_g.remove_weight_norm() return net_g, cpt + def get_synthesizer(pth_path, device=torch.device("cpu")): return get_synthesizer_ckpt( - torch.load(pth_path, map_location=torch.device("cpu")), device, + torch.load(pth_path, map_location=torch.device("cpu")), + device, ) diff --git a/infer/modules/vc/modules.py b/infer/modules/vc/modules.py index fe71180..d33a2bb 100644 --- a/infer/modules/vc/modules.py +++ b/infer/modules/vc/modules.py @@ -62,7 +62,9 @@ class VC: elif torch.backends.mps.is_available(): torch.mps.empty_cache() ###楼下不这么折腾清理不干净 - self.net_g, self.cpt = get_synthesizer_ckpt(self.cpt, self.config.device) + self.net_g, self.cpt = get_synthesizer_ckpt( + self.cpt, self.config.device + ) self.if_f0 = self.cpt.get("f0", 1) self.version = self.cpt.get("version", "v1") del self.net_g, self.cpt