From 978abd8aac119daa663694793c284339220a0b28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 7 Jun 2024 19:53:23 +0900 Subject: [PATCH] optimize(infer): move transforms into rvc --- infer/lib/infer_pack/modules.py | 4 +-- {infer/lib/infer_pack => rvc}/transforms.py | 38 ++++++++++----------- 2 files changed, 21 insertions(+), 21 deletions(-) rename {infer/lib/infer_pack => rvc}/transforms.py (88%) diff --git a/infer/lib/infer_pack/modules.py b/infer/lib/infer_pack/modules.py index 593c301..a94c510 100644 --- a/infer/lib/infer_pack/modules.py +++ b/infer/lib/infer_pack/modules.py @@ -12,7 +12,7 @@ from torch.nn.utils import remove_weight_norm, weight_norm from rvc import utils from rvc.utils import get_padding, call_weight_data_normal_if_Conv -from infer.lib.infer_pack.transforms import piecewise_rational_quadratic_transform +from rvc.transforms import piecewise_rational_quadratic_transform LRELU_SLOPE = 0.1 @@ -583,7 +583,7 @@ class ConvFlow(nn.Module): reverse=False, ): x0, x1 = torch.split(x, [self.half_channels] * 2, 1) - h = self.pre(x0) + h: torch.Tensor = self.pre(x0) h = self.convs(h, x_mask, g=g) h = self.proj(h) * x_mask diff --git a/infer/lib/infer_pack/transforms.py b/rvc/transforms.py similarity index 88% rename from infer/lib/infer_pack/transforms.py rename to rvc/transforms.py index 6d07b3b..679882f 100644 --- a/infer/lib/infer_pack/transforms.py +++ b/rvc/transforms.py @@ -8,13 +8,13 @@ DEFAULT_MIN_DERIVATIVE = 1e-3 def piecewise_rational_quadratic_transform( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails=None, - tail_bound=1.0, + inputs: torch.Tensor, + unnormalized_widths: torch.Tensor, + unnormalized_heights: torch.Tensor, + unnormalized_derivatives: torch.Tensor, + inverse: bool = False, + tails: str | None = None, + tail_bound: float = 1.0, min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE, @@ -46,13 +46,13 @@ def searchsorted(bin_locations, inputs, eps=1e-6): def unconstrained_rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails="linear", - tail_bound=1.0, + inputs: torch.Tensor, + unnormalized_widths: torch.Tensor, + unnormalized_heights: torch.Tensor, + unnormalized_derivatives: torch.Tensor, + inverse: bool = False, + tails: str = "linear", + tail_bound: float = 1.0, min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE, @@ -96,11 +96,11 @@ def unconstrained_rational_quadratic_spline( def rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, + inputs: torch.Tensor, + unnormalized_widths: torch.Tensor, + unnormalized_heights: torch.Tensor, + unnormalized_derivatives: torch.Tensor, + inverse: bool = False, left=0.0, right=1.0, bottom=0.0,