1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-07 19:40:44 +08:00

optimize(infer): move attentions into rvc

This commit is contained in:
源文雨
2024-06-07 20:28:05 +09:00
parent 978abd8aac
commit 96604e8175
12 changed files with 195 additions and 341 deletions

View File

@@ -1,432 +0,0 @@
import math
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
from infer.lib.infer_pack.modules import LayerNorm
class Encoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=10,
):
super(Encoder, self).__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = int(n_layers)
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def __call__(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
return super().__call__(x, x_mask)
def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
zippep = zip(
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
)
for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zippep:
y = attn_layers(x, x, attn_mask)
y = self.drop(y)
x = norm_layers_1(x + y)
y = ffn_layers(x, x_mask)
y = self.drop(y)
x = norm_layers_2(x + y)
x = x * x_mask
return x
"""
class Decoder(nn.Module):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
):
super(Decoder, self).__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
self.encdec_attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
# x: decoder input
# h: encoder output
self_attn_mask = utils.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
"""
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super(MultiHeadAttention, self).__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init:
with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(
self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, _ = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s = key.size()
t_t = query.size(2)
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert (
t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s
)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length: int):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length: int = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
[0, 0, pad_length, pad_length, 0, 0],
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0], )
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0])
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
x_flat = x.view([batch, heads, (length**2) + (length * (length - 1))])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length: int):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation: str = None,
causal=False,
):
super(FFN, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
self.is_activation = True if activation == "gelu" else False
# if causal:
# self.padding = self._causal_padding
# else:
# self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
if self.causal:
padding = self._causal_padding(x * x_mask)
else:
padding = self._same_padding(x * x_mask)
return padding
def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
x = self.conv_1(self.padding(x, x_mask))
if self.is_activation:
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x, x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l: int = self.kernel_size - 1
pad_r: int = 0
# padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, [pad_l, pad_r, 0, 0, 0, 0])
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l: int = (self.kernel_size - 1) // 2
pad_r: int = self.kernel_size // 2
# padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, [pad_l, pad_r, 0, 0, 0, 0])
return x

View File

@@ -2,104 +2,19 @@ import math
import logging
from typing import Optional, Tuple, List
from rvc import utils
logger = logging.getLogger(__name__)
import torch
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from infer.lib.infer_pack import attentions, modules
from rvc.utils import get_padding, call_weight_data_normal_if_Conv
from infer.lib.infer_pack import modules
from rvc.utils import get_padding, call_weight_data_normal_if_Conv, sequence_mask, slice_on_last_dim, rand_slice_segments_on_last_dim
from rvc.encoders import TextEncoder
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class TextEncoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
f0=True,
):
super(TextEncoder, self).__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = float(p_dropout)
self.emb_phone = nn.Linear(in_channels, hidden_channels)
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
if f0 == True:
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
self.encoder = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
float(p_dropout),
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def __call__(
self,
phone: torch.Tensor,
pitch: torch.Tensor,
lengths: torch.Tensor,
# skip_head: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return super().__call__(
phone,
pitch,
lengths,
# skip_head=skip_head,
)
def forward(
self,
phone: torch.Tensor,
pitch: torch.Tensor,
lengths: torch.Tensor,
# skip_head: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if pitch is None:
x = self.emb_phone(phone)
else:
x = self.emb_phone(phone) + self.emb_pitch(pitch)
x = x * math.sqrt(self.hidden_channels) # [b, t, h]
x = self.lrelu(x)
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(
utils.sequence_mask(
lengths,
x.size(2),
),
1,
).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
"""
if skip_head is not None:
assert isinstance(skip_head, torch.Tensor)
head = int(skip_head.item())
x = x[:, :, head:]
x_mask = x_mask[:, :, head:]
"""
stats: torch.Tensor = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
def __init__(
self,
@@ -205,11 +120,7 @@ class PosteriorEncoder(nn.Module):
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x_mask = torch.unsqueeze(
utils.sequence_mask(
x_lengths,
x.size(2),
),
1,
sequence_mask(x_lengths, x.size(2)), 1,
).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
@@ -728,12 +639,6 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
logger.debug(
"gin_channels: "
+ str(gin_channels)
+ ", self.spk_embed_dim: "
+ str(self.spk_embed_dim)
)
def remove_weight_norm(self):
self.dec.remove_weight_norm()
@@ -783,9 +688,9 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
z_slice, ids_slice = utils.rand_slice_segments_on_last_dim(z, y_lengths, self.segment_size)
z_slice, ids_slice = rand_slice_segments_on_last_dim(z, y_lengths, self.segment_size)
# print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
pitchf = utils.slice_on_last_dim(pitchf, ids_slice, self.segment_size)
pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
# print(-2,pitchf.shape,z_slice.shape)
o = self.dec(z_slice, pitchf, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
@@ -962,12 +867,6 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
)
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
logger.debug(
"gin_channels: "
+ str(gin_channels)
+ ", self.spk_embed_dim: "
+ str(self.spk_embed_dim)
)
def remove_weight_norm(self):
self.dec.remove_weight_norm()
@@ -1007,7 +906,7 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
z_slice, ids_slice = utils.rand_slice_segments_on_last_dim(z, y_lengths, self.segment_size)
z_slice, ids_slice = rand_slice_segments_on_last_dim(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

View File

@@ -1,13 +1,14 @@
import torch
from torch import nn
from .attentions import (
TextEncoder,
from .models import (
ResidualCouplingBlock,
PosteriorEncoder,
GeneratorNSF,
)
from rvc.encoders import TextEncoder
class SynthesizerTrnMsNSFsidM(nn.Module):
def __init__(

View File

@@ -1,89 +1,19 @@
import copy
import math
from typing import Optional, Tuple
import numpy as np
import scipy
import torch
from torch import nn
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from rvc import utils
from rvc.utils import get_padding, call_weight_data_normal_if_Conv
from rvc.utils import get_padding, call_weight_data_normal_if_Conv, activate_add_tanh_sigmoid_multiply
from rvc.transforms import piecewise_rational_quadratic_transform
from rvc.norms import LayerNorm
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super(LayerNorm, self).__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(
self,
in_channels,
hidden_channels,
out_channels,
kernel_size,
n_layers,
p_dropout,
):
super(ConvReluNorm, self).__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = float(p_dropout)
assert n_layers > 1, "Number of layers should be larger than 0."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(float(p_dropout)))
for _ in range(n_layers - 1):
self.conv_layers.append(
nn.Conv1d(
hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2,
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
@@ -203,7 +133,7 @@ class WN(torch.nn.Module):
else:
g_l = torch.zeros_like(x_in)
acts = utils.activate_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels)
acts = activate_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels)
acts = self.drop(acts)
res_skip_acts = res_skip_layer(acts)