1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-08 03:55:47 +08:00

optimize(rvc.utils): more type defs & rename

This commit is contained in:
源文雨
2024-06-07 19:33:45 +09:00
parent c10c527264
commit 49488dcae9
6 changed files with 41 additions and 75 deletions

View File

@@ -136,7 +136,7 @@ class DDSConv(nn.Module):
class WN(torch.nn.Module):
def __init__(
self,
hidden_channels,
hidden_channels: int,
kernel_size,
dilation_rate,
n_layers,
@@ -189,7 +189,6 @@ class WN(torch.nn.Module):
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
@@ -197,14 +196,14 @@ class WN(torch.nn.Module):
for i, (in_layer, res_skip_layer) in enumerate(
zip(self.in_layers, self.res_skip_layers)
):
x_in = in_layer(x)
x_in: torch.Tensor = in_layer(x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = utils.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = utils.activate_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels)
acts = self.drop(acts)
res_skip_acts = res_skip_layer(acts)