mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-08 20:10:44 +08:00
fix(train): mysterious importing order
This commit is contained in:
@@ -1,8 +1,7 @@
|
|||||||
from typing import Any, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torchfcpe import spawn_bundled_infer_model
|
|
||||||
|
|
||||||
from .f0 import F0Predictor
|
from .f0 import F0Predictor
|
||||||
|
|
||||||
@@ -24,6 +23,8 @@ class FCPE(F0Predictor):
|
|||||||
device,
|
device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from torchfcpe import spawn_bundled_infer_model # must be imported at here, or it will cause fairseq crash on training
|
||||||
|
|
||||||
self.model = spawn_bundled_infer_model(self.device)
|
self.model = spawn_bundled_infer_model(self.device)
|
||||||
|
|
||||||
def compute_f0(
|
def compute_f0(
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ class PosteriorEncoder(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
|
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
super().__call__(x, x_lengths, g=g)
|
return super().__call__(x, x_lengths, g=g)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
|
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ def slice_on_last_dim(
|
|||||||
start_indices: List[int],
|
start_indices: List[int],
|
||||||
segment_size=4,
|
segment_size=4,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
new_shape = x.shape
|
new_shape = [*x.shape]
|
||||||
new_shape[-1] = segment_size
|
new_shape[-1] = segment_size
|
||||||
ret = torch.empty(new_shape)
|
ret = torch.empty(new_shape, device=x.device)
|
||||||
for i in range(x.size(0)):
|
for i in range(x.size(0)):
|
||||||
idx_str = start_indices[i]
|
idx_str = start_indices[i]
|
||||||
idx_end = idx_str + segment_size
|
idx_end = idx_str + segment_size
|
||||||
|
|||||||
Reference in New Issue
Block a user