mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-07 02:00:25 +08:00
optimize(rvc.utils): more type defs & rename
This commit is contained in:
36
rvc/utils.py
36
rvc/utils.py
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Iterator
|
||||
|
||||
import torch
|
||||
|
||||
@@ -11,7 +11,7 @@ def call_weight_data_normal_if_Conv(m: torch.nn.Module):
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_padding(kernel_size: int, dilation=1):
|
||||
def get_padding(kernel_size: int, dilation=1) -> int:
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ def slice_on_last_dim(
|
||||
return ret
|
||||
|
||||
|
||||
def rand_slice_segments(
|
||||
def rand_slice_segments_on_last_dim(
|
||||
x: torch.Tensor,
|
||||
x_lengths: int = None,
|
||||
segment_size=4,
|
||||
@@ -45,19 +45,16 @@ def rand_slice_segments(
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
def activate_add_tanh_sigmoid_multiply(
|
||||
input_a: torch.Tensor, input_b: torch.Tensor, n_channels: int
|
||||
) -> torch.Tensor:
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
t_act = torch.tanh(in_act[:, :n_channels, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
|
||||
return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()
|
||||
|
||||
|
||||
def sequence_mask(
|
||||
length: torch.Tensor,
|
||||
max_length: Optional[int] = None,
|
||||
@@ -68,19 +65,16 @@ def sequence_mask(
|
||||
return x.unsqueeze(0) < length.unsqueeze(1)
|
||||
|
||||
|
||||
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
||||
def total_grad_norm(
|
||||
parameters: Iterator[torch.nn.Parameter], norm_type: float=2.0,
|
||||
) -> float:
|
||||
norm_type = float(norm_type)
|
||||
if clip_value is not None:
|
||||
clip_value = float(clip_value)
|
||||
total_norm = 0.0
|
||||
|
||||
total_norm = 0
|
||||
for p in parameters:
|
||||
if p.grad is None: continue
|
||||
param_norm = p.grad.data.norm(norm_type)
|
||||
total_norm += param_norm.item() ** norm_type
|
||||
if clip_value is not None:
|
||||
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
||||
total_norm += float(param_norm.item()) ** norm_type
|
||||
total_norm = total_norm ** (1.0 / norm_type)
|
||||
|
||||
return total_norm
|
||||
|
||||
Reference in New Issue
Block a user