1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-11 05:30:30 +08:00

optimize(rvc): move commons to rvc.utils

- remove redundant attentions_onnx
- shrink models_onnx
- add some type annotations to rvc.utils
This commit is contained in:
源文雨
2024-06-07 00:42:35 +09:00
parent 6f90ce3046
commit 5eed789fe7
8 changed files with 186 additions and 1477 deletions

View File

@@ -1,13 +1,10 @@
import copy
import math
from typing import Optional
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from infer.lib.infer_pack import commons, modules
from infer.lib.infer_pack.modules import LayerNorm
@@ -76,7 +73,7 @@ class Encoder(nn.Module):
x = x * x_mask
return x
"""
class Decoder(nn.Module):
def __init__(
self,
@@ -138,11 +135,9 @@ class Decoder(nn.Module):
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
"""
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
# x: decoder input
# h: encoder output
self_attn_mask = utils.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
@@ -161,7 +156,7 @@ class Decoder(nn.Module):
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
"""
class MultiHeadAttention(nn.Module):
def __init__(
@@ -342,7 +337,7 @@ class MultiHeadAttention(nn.Module):
x_flat = F.pad(
x_flat,
# commons.convert_pad_shape([[0, 0], [0, 0], [0, int(length) - 1]])
[0, int(length) - 1, 0, 0, 0, 0],
[0, length - 1, 0, 0, 0, 0],
)
# Reshape and slice out the padded elements.
@@ -361,9 +356,9 @@ class MultiHeadAttention(nn.Module):
x = F.pad(
x,
# commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, int(length) - 1]])
[0, int(length) - 1, 0, 0, 0, 0, 0, 0],
[0, length - 1, 0, 0, 0, 0, 0, 0],
)
x_flat = x.view([batch, heads, int(length**2) + int(length * (length - 1))])
x_flat = x.view([batch, heads, (length**2) + (length * (length - 1))])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(
x_flat,