1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 17:20:25 +08:00

fix(fairseq): hubert load model error

This commit is contained in:
源文雨
2026-04-18 19:04:13 +08:00
parent 8ded36e9e1
commit f9ae0b5d32
12 changed files with 101 additions and 151 deletions

View File

@@ -29,10 +29,12 @@ try:
GradScaler = gradscaler_init()
ipex_init()
else:
from torch.cuda.amp import GradScaler, autocast
except Exception:
from torch.cuda.amp import GradScaler, autocast
pass
finally:
if not ('GradScaler' in globals() and 'autocast' in globals()):
from torch.amp.grad_scaler import GradScaler
from torch.amp.autocast_mode import autocast
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
@@ -535,7 +537,7 @@ def train_and_evaluate(
# wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
# Calculate
with autocast(enabled=hps.train.fp16_run):
with autocast(device_type="cuda", enabled=hps.train.fp16_run):
(
y_hat,
ids_slice,
@@ -554,7 +556,7 @@ def train_and_evaluate(
y_mel = slice_on_last_dim(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
with autocast(enabled=False):
with autocast(device_type="cuda", enabled=False):
y_hat_mel = mel_spectrogram_torch(
y_hat.float().squeeze(1),
hps.data.filter_length,
@@ -573,7 +575,7 @@ def train_and_evaluate(
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
with autocast(enabled=False):
with autocast(device_type="cuda", enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g
)
@@ -583,10 +585,10 @@ def train_and_evaluate(
grad_norm_d = total_grad_norm(net_d.parameters())
scaler.step(optim_d)
with autocast(enabled=hps.train.fp16_run):
with autocast(device_type="cuda", enabled=hps.train.fp16_run):
# Generator
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
with autocast(enabled=False):
with autocast(device_type="cuda", enabled=False):
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
loss_fm = feature_loss(fmap_r, fmap_g)