mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-09 12:30:38 +08:00
optimize(infer): move modules into rvc
This commit is contained in:
@@ -9,15 +9,15 @@ from torch.nn import functional as F
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
out_channels,
|
||||
n_heads,
|
||||
p_dropout=0.0,
|
||||
window_size=None,
|
||||
heads_share=True,
|
||||
block_length=None,
|
||||
proximal_bias=False,
|
||||
proximal_init=False,
|
||||
channels: int,
|
||||
out_channels: int,
|
||||
n_heads: int,
|
||||
p_dropout: float = 0.0,
|
||||
window_size: int | None = None,
|
||||
heads_share: bool = True,
|
||||
block_length: int | None = None,
|
||||
proximal_bias: bool = False,
|
||||
proximal_init: bool = False,
|
||||
):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
assert channels % n_heads == 0
|
||||
@@ -60,19 +60,30 @@ class MultiHeadAttention(nn.Module):
|
||||
self.conv_k.weight.copy_(self.conv_q.weight)
|
||||
self.conv_k.bias.copy_(self.conv_q.bias)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return super().__call__(x, c, attn_mask=attn_mask)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
|
||||
):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
|
||||
x, _ = self.attention(q, k, v, mask=attn_mask)
|
||||
x, _ = self._attention(q, k, v, mask=attn_mask)
|
||||
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(
|
||||
def _attention(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
@@ -149,7 +160,7 @@ class MultiHeadAttention(nn.Module):
|
||||
return ret
|
||||
|
||||
def _get_relative_embeddings(self, relative_embeddings, length: int):
|
||||
max_relative_position = 2 * self.window_size + 1
|
||||
# max_relative_position = 2 * self.window_size + 1
|
||||
# Pad first before slice to avoid using cond ops.
|
||||
pad_length: int = max(length - (self.window_size + 1), 0)
|
||||
slice_start_position = max((self.window_size + 1) - length, 0)
|
||||
@@ -217,13 +228,13 @@ class MultiHeadAttention(nn.Module):
|
||||
class FFN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
p_dropout=0.0,
|
||||
activation: str = None,
|
||||
causal=False,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
filter_channels: int,
|
||||
kernel_size: int,
|
||||
p_dropout: float = 0.0,
|
||||
activation: str | None = None,
|
||||
causal: bool = False,
|
||||
):
|
||||
super(FFN, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
@@ -234,32 +245,29 @@ class FFN(nn.Module):
|
||||
self.activation = activation
|
||||
self.causal = causal
|
||||
self.is_activation = True if activation == "gelu" else False
|
||||
# if causal:
|
||||
# self.padding = self._causal_padding
|
||||
# else:
|
||||
# self.padding = self._same_padding
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
|
||||
def padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
|
||||
if self.causal:
|
||||
padding = self._causal_padding(x * x_mask)
|
||||
else:
|
||||
padding = self._same_padding(x * x_mask)
|
||||
return padding
|
||||
def __call__(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
|
||||
return super().__call__(x, x_mask)
|
||||
|
||||
def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
|
||||
x = self.conv_1(self.padding(x, x_mask))
|
||||
def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv_1(self._padding(x, x_mask))
|
||||
if self.is_activation:
|
||||
x = x * torch.sigmoid(1.702 * x)
|
||||
else:
|
||||
x = torch.relu(x)
|
||||
x = self.drop(x)
|
||||
|
||||
x = self.conv_2(self.padding(x, x_mask))
|
||||
x = self.conv_2(self._padding(x, x_mask))
|
||||
return x * x_mask
|
||||
|
||||
def _padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
|
||||
if self.causal:
|
||||
return self._causal_padding(x * x_mask)
|
||||
return self._same_padding(x * x_mask)
|
||||
|
||||
def _causal_padding(self, x):
|
||||
if self.kernel_size == 1:
|
||||
|
||||
Reference in New Issue
Block a user