1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-07 02:00:25 +08:00

optimize(rvc.utils): more type defs & rename

This commit is contained in:
源文雨
2024-06-07 19:33:45 +09:00
parent c10c527264
commit 49488dcae9
6 changed files with 41 additions and 75 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Iterator
import torch
@@ -11,7 +11,7 @@ def call_weight_data_normal_if_Conv(m: torch.nn.Module):
m.weight.data.normal_(mean, std)
def get_padding(kernel_size: int, dilation=1):
def get_padding(kernel_size: int, dilation=1) -> int:
return int((kernel_size * dilation - dilation) / 2)
@@ -30,7 +30,7 @@ def slice_on_last_dim(
return ret
def rand_slice_segments(
def rand_slice_segments_on_last_dim(
x: torch.Tensor,
x_lengths: int = None,
segment_size=4,
@@ -45,19 +45,16 @@ def rand_slice_segments(
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
def activate_add_tanh_sigmoid_multiply(
input_a: torch.Tensor, input_b: torch.Tensor, n_channels: int
) -> torch.Tensor:
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
t_act = torch.tanh(in_act[:, :n_channels, :])
s_act = torch.sigmoid(in_act[:, n_channels:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()
def sequence_mask(
length: torch.Tensor,
max_length: Optional[int] = None,
@@ -68,19 +65,16 @@ def sequence_mask(
return x.unsqueeze(0) < length.unsqueeze(1)
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
def total_grad_norm(
parameters: Iterator[torch.nn.Parameter], norm_type: float=2.0,
) -> float:
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0.0
total_norm = 0
for p in parameters:
if p.grad is None: continue
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm += float(param_norm.item()) ** norm_type
total_norm = total_norm ** (1.0 / norm_type)
return total_norm