1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

chore(format): run black on dev (#19)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2024-06-10 01:14:46 +09:00
committed by GitHub
parent b23ea7c6e7
commit b67050b2f7
2 changed files with 59 additions and 25 deletions

View File

@@ -304,7 +304,17 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
def train_and_evaluate(
rank, epoch, hps, nets: Tuple[RVC_Model_f0, MultiPeriodDiscriminator], optims, schedulers, scaler, loaders, logger, writers, cache
rank,
epoch,
hps,
nets: Tuple[RVC_Model_f0, MultiPeriodDiscriminator],
optims,
schedulers,
scaler,
loaders,
logger,
writers,
cache,
):
net_g, net_d = nets
optim_g, optim_d = optims

View File

@@ -14,21 +14,41 @@ class MultiPeriodDiscriminator(torch.nn.Module):
"""
version: 'v1' or 'v2'
"""
def __init__(self, version: str, use_spectral_norm: bool = False, has_xpu: bool = False):
super(MultiPeriodDiscriminator, self).__init__()
periods = (2, 3, 5, 7, 11, 17) if version == "v1" else (2, 3, 5, 7, 11, 17, 23, 37)
self.discriminators = nn.ModuleList([
def __init__(
self, version: str, use_spectral_norm: bool = False, has_xpu: bool = False
):
super(MultiPeriodDiscriminator, self).__init__()
periods = (
(2, 3, 5, 7, 11, 17) if version == "v1" else (2, 3, 5, 7, 11, 17, 23, 37)
)
self.discriminators = nn.ModuleList(
[
DiscriminatorS(use_spectral_norm=use_spectral_norm),
*(
DiscriminatorP(i, use_spectral_norm=use_spectral_norm, has_xpu=has_xpu) for i in periods
DiscriminatorP(
i, use_spectral_norm=use_spectral_norm, has_xpu=has_xpu
)
for i in periods
),
]
)
])
def __call__(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
def __call__(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
return super().__call__(y, y_hat)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
@@ -97,8 +117,9 @@ class DiscriminatorP(torch.nn.Module):
convs_padding = (get_padding(kernel_size, 1), 0)
self.convs = nn.ModuleList()
for i in range(len(sequence)-1):
self.convs.append(norm_f(
for i in range(len(sequence) - 1):
self.convs.append(
norm_f(
Conv2d(
sequence[i],
sequence[i + 1],
@@ -106,8 +127,10 @@ class DiscriminatorP(torch.nn.Module):
(stride, 1),
padding=convs_padding,
)
))
self.convs.append(norm_f(
)
)
self.convs.append(
norm_f(
Conv2d(
1024,
1024,
@@ -115,7 +138,8 @@ class DiscriminatorP(torch.nn.Module):
1,
padding=convs_padding,
)
))
)
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: