1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-06 09:40:24 +08:00

fix(dml): train extract_f0_print error

ModuleNotFoundError: No module named 'torch.privateuseone' due to new prosess
This commit is contained in:
源文雨
2025-11-21 16:52:17 +08:00
parent 7fa122045f
commit 43d19eb00e
13 changed files with 50 additions and 39 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Iterator
from typing import List, Optional, Tuple, Iterator, Union
import torch
@@ -17,7 +17,7 @@ def get_padding(kernel_size: int, dilation=1) -> int:
def slice_on_last_dim(
x: torch.Tensor,
start_indices: List[int],
start_indices: Union[List[int], torch.Tensor],
segment_size=4,
) -> torch.Tensor:
new_shape = [*x.shape]
@@ -32,9 +32,9 @@ def slice_on_last_dim(
def rand_slice_segments_on_last_dim(
x: torch.Tensor,
x_lengths: int = None,
x_lengths: Optional[Union[int, torch.Tensor]] = None,
segment_size=4,
) -> Tuple[torch.Tensor, List[int]]:
) -> Tuple[torch.Tensor, Union[List[int], torch.Tensor]]:
b, _, t = x.size()
if x_lengths is None:
x_lengths = t
@@ -58,7 +58,7 @@ def activate_add_tanh_sigmoid_multiply(
def sequence_mask(
length: torch.Tensor,
max_length: Optional[int] = None,
) -> torch.BoolTensor:
):
if max_length is None:
max_length = int(length.max())
x = torch.arange(max_length, dtype=length.dtype, device=length.device)