1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-09 12:30:38 +08:00

optimize(infer): move modules into rvc

This commit is contained in:
源文雨
2024-06-08 00:14:03 +09:00
parent 44725ddd2c
commit eb24434260
8 changed files with 468 additions and 618 deletions

View File

@@ -9,15 +9,15 @@ from torch.nn import functional as F
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
channels: int,
out_channels: int,
n_heads: int,
p_dropout: float = 0.0,
window_size: int | None = None,
heads_share: bool = True,
block_length: int | None = None,
proximal_bias: bool = False,
proximal_init: bool = False,
):
super(MultiHeadAttention, self).__init__()
assert channels % n_heads == 0
@@ -60,19 +60,30 @@ class MultiHeadAttention(nn.Module):
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def __call__(
self,
x: torch.Tensor,
c: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return super().__call__(x, c, attn_mask=attn_mask)
def forward(
self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
):
self,
x: torch.Tensor,
c: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, _ = self.attention(q, k, v, mask=attn_mask)
x, _ = self._attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(
def _attention(
self,
query: torch.Tensor,
key: torch.Tensor,
@@ -149,7 +160,7 @@ class MultiHeadAttention(nn.Module):
return ret
def _get_relative_embeddings(self, relative_embeddings, length: int):
max_relative_position = 2 * self.window_size + 1
# max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length: int = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
@@ -217,13 +228,13 @@ class MultiHeadAttention(nn.Module):
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation: str = None,
causal=False,
in_channels: int,
out_channels: int,
filter_channels: int,
kernel_size: int,
p_dropout: float = 0.0,
activation: str | None = None,
causal: bool = False,
):
super(FFN, self).__init__()
self.in_channels = in_channels
@@ -234,32 +245,29 @@ class FFN(nn.Module):
self.activation = activation
self.causal = causal
self.is_activation = True if activation == "gelu" else False
# if causal:
# self.padding = self._causal_padding
# else:
# self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
if self.causal:
padding = self._causal_padding(x * x_mask)
else:
padding = self._same_padding(x * x_mask)
return padding
def __call__(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
return super().__call__(x, x_mask)
def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
x = self.conv_1(self.padding(x, x_mask))
def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
x = self.conv_1(self._padding(x, x_mask))
if self.is_activation:
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x, x_mask))
x = self.conv_2(self._padding(x, x_mask))
return x * x_mask
def _padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
if self.causal:
return self._causal_padding(x * x_mask)
return self._same_padding(x * x_mask)
def _causal_padding(self, x):
if self.kernel_size == 1: