From b67050b2f7f5bfeaaec804691d90d84fbd1101d6 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 10 Jun 2024 01:14:46 +0900 Subject: [PATCH] chore(format): run black on dev (#19) Co-authored-by: github-actions[bot] --- infer/modules/train/train.py | 12 +++++- rvc/discriminators.py | 72 ++++++++++++++++++++++++------------ 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index c511fcd..1ba6b18 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -304,7 +304,17 @@ def run(rank, n_gpus, hps, logger: logging.Logger): def train_and_evaluate( - rank, epoch, hps, nets: Tuple[RVC_Model_f0, MultiPeriodDiscriminator], optims, schedulers, scaler, loaders, logger, writers, cache + rank, + epoch, + hps, + nets: Tuple[RVC_Model_f0, MultiPeriodDiscriminator], + optims, + schedulers, + scaler, + loaders, + logger, + writers, + cache, ): net_g, net_d = nets optim_g, optim_d = optims diff --git a/rvc/discriminators.py b/rvc/discriminators.py index 784265a..29e2dff 100644 --- a/rvc/discriminators.py +++ b/rvc/discriminators.py @@ -14,21 +14,41 @@ class MultiPeriodDiscriminator(torch.nn.Module): """ version: 'v1' or 'v2' """ - def __init__(self, version: str, use_spectral_norm: bool = False, has_xpu: bool = False): + + def __init__( + self, version: str, use_spectral_norm: bool = False, has_xpu: bool = False + ): super(MultiPeriodDiscriminator, self).__init__() - periods = (2, 3, 5, 7, 11, 17) if version == "v1" else (2, 3, 5, 7, 11, 17, 23, 37) + periods = ( + (2, 3, 5, 7, 11, 17) if version == "v1" else (2, 3, 5, 7, 11, 17, 23, 37) + ) - self.discriminators = nn.ModuleList([ - DiscriminatorS(use_spectral_norm=use_spectral_norm), - *( - DiscriminatorP(i, use_spectral_norm=use_spectral_norm, has_xpu=has_xpu) for i in periods - ) - ]) + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=use_spectral_norm), + *( + DiscriminatorP( + i, use_spectral_norm=use_spectral_norm, has_xpu=has_xpu + ) + for i in periods + ), + ] + ) - def __call__(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + def __call__(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: return super().__call__(y, y_hat) - def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: y_d_rs = [] y_d_gs = [] fmap_rs = [] @@ -97,25 +117,29 @@ class DiscriminatorP(torch.nn.Module): convs_padding = (get_padding(kernel_size, 1), 0) self.convs = nn.ModuleList() - for i in range(len(sequence)-1): - self.convs.append(norm_f( + for i in range(len(sequence) - 1): + self.convs.append( + norm_f( + Conv2d( + sequence[i], + sequence[i + 1], + (kernel_size, 1), + (stride, 1), + padding=convs_padding, + ) + ) + ) + self.convs.append( + norm_f( Conv2d( - sequence[i], - sequence[i + 1], + 1024, + 1024, (kernel_size, 1), - (stride, 1), + 1, padding=convs_padding, ) - )) - self.convs.append(norm_f( - Conv2d( - 1024, - 1024, - (kernel_size, 1), - 1, - padding=convs_padding, ) - )) + ) self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: