1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00

fix(train): mysterious importing order

This commit is contained in:
源文雨
2024-06-16 19:33:23 +09:00
parent add4642b7e
commit d9a116f4f7
3 changed files with 6 additions and 5 deletions

View File

@@ -1,8 +1,7 @@
from typing import Any, Optional, Union
from typing import Optional, Union
import numpy as np
import torch
from torchfcpe import spawn_bundled_infer_model
from .f0 import F0Predictor
@@ -24,6 +23,8 @@ class FCPE(F0Predictor):
device,
)
from torchfcpe import spawn_bundled_infer_model # must be imported at here, or it will cause fairseq crash on training
self.model = spawn_bundled_infer_model(self.device)
def compute_f0(

View File

@@ -192,7 +192,7 @@ class PosteriorEncoder(nn.Module):
def __call__(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
super().__call__(x, x_lengths, g=g)
return super().__call__(x, x_lengths, g=g)
def forward(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None

View File

@@ -20,9 +20,9 @@ def slice_on_last_dim(
start_indices: List[int],
segment_size=4,
) -> torch.Tensor:
new_shape = x.shape
new_shape = [*x.shape]
new_shape[-1] = segment_size
ret = torch.empty(new_shape)
ret = torch.empty(new_shape, device=x.device)
for i in range(x.size(0)):
idx_str = start_indices[i]
idx_end = idx_str + segment_size