# Mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
# (synced 2026-06-05 17:20:25 +08:00; 674 lines, 23 KiB, Python).
from typing import Optional, List
|
||
|
||
import torch
|
||
from torch import nn
|
||
from torch.nn import Conv1d, Conv2d
|
||
from torch.nn import functional as F
|
||
from torch.nn.utils import spectral_norm, weight_norm
|
||
from rvc import residuals
|
||
|
||
from rvc.residuals import ResidualCouplingBlock
|
||
from rvc.utils import (
|
||
get_padding,
|
||
slice_on_last_dim,
|
||
rand_slice_segments_on_last_dim,
|
||
)
|
||
from rvc.encoders import TextEncoder, PosteriorEncoder
|
||
from rvc.generators import Generator
|
||
from rvc.nsf import NSFGenerator
|
||
|
||
# True only when this torch build exposes a `torch.xpu` attribute and that
# backend reports itself as available. DiscriminatorP.forward reads this flag
# to decide whether bfloat16 input must be cast to float16 around reflect padding.
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
|
||
|
||
|
||
class SynthesizerTrnMsNSFsid(nn.Module):
|
||
def __init__(
|
||
self,
|
||
spec_channels: int,
|
||
segment_size: int,
|
||
inter_channels: int,
|
||
hidden_channels: int,
|
||
filter_channels: int,
|
||
n_heads: int,
|
||
n_layers: int,
|
||
kernel_size: int,
|
||
p_dropout: int,
|
||
resblock: str,
|
||
resblock_kernel_sizes: List[int],
|
||
resblock_dilation_sizes: List[List[int]],
|
||
upsample_rates: List[int],
|
||
upsample_initial_channel: int,
|
||
upsample_kernel_sizes: List[int],
|
||
spk_embed_dim: int,
|
||
gin_channels: int,
|
||
sr: str | int,
|
||
text_encoder_in_channels: int,
|
||
):
|
||
super(SynthesizerTrnMs256NSFsid, self).__init__()
|
||
if isinstance(sr, str):
|
||
sr = {
|
||
"32k": 32000,
|
||
"40k": 40000,
|
||
"48k": 48000,
|
||
}[sr]
|
||
self.spec_channels = spec_channels
|
||
self.inter_channels = inter_channels
|
||
self.hidden_channels = hidden_channels
|
||
self.filter_channels = filter_channels
|
||
self.n_heads = n_heads
|
||
self.n_layers = n_layers
|
||
self.kernel_size = kernel_size
|
||
self.p_dropout = float(p_dropout)
|
||
self.resblock = resblock
|
||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||
self.upsample_rates = upsample_rates
|
||
self.upsample_initial_channel = upsample_initial_channel
|
||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||
self.segment_size = segment_size
|
||
self.gin_channels = gin_channels
|
||
# self.hop_length = hop_length#
|
||
self.spk_embed_dim = spk_embed_dim
|
||
self.enc_p = TextEncoder(
|
||
text_encoder_in_channels,
|
||
inter_channels,
|
||
hidden_channels,
|
||
filter_channels,
|
||
n_heads,
|
||
n_layers,
|
||
kernel_size,
|
||
float(p_dropout),
|
||
)
|
||
self.dec = NSFGenerator(
|
||
inter_channels,
|
||
resblock,
|
||
resblock_kernel_sizes,
|
||
resblock_dilation_sizes,
|
||
upsample_rates,
|
||
upsample_initial_channel,
|
||
upsample_kernel_sizes,
|
||
gin_channels=gin_channels,
|
||
sr=sr,
|
||
)
|
||
self.enc_q = PosteriorEncoder(
|
||
spec_channels,
|
||
inter_channels,
|
||
hidden_channels,
|
||
5,
|
||
1,
|
||
16,
|
||
gin_channels=gin_channels,
|
||
)
|
||
self.flow = ResidualCouplingBlock(
|
||
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
|
||
)
|
||
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
||
|
||
def remove_weight_norm(self):
|
||
self.dec.remove_weight_norm()
|
||
self.flow.remove_weight_norm()
|
||
if hasattr(self, "enc_q"):
|
||
self.enc_q.remove_weight_norm()
|
||
|
||
def __prepare_scriptable__(self):
|
||
for hook in self.dec._forward_pre_hooks.values():
|
||
# The hook we want to remove is an instance of WeightNorm class, so
|
||
# normally we would do `if isinstance(...)` but this class is not accessible
|
||
# because of shadowing, so we check the module name directly.
|
||
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
|
||
if (
|
||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||
and hook.__class__.__name__ == "WeightNorm"
|
||
):
|
||
torch.nn.utils.remove_weight_norm(self.dec)
|
||
for hook in self.flow._forward_pre_hooks.values():
|
||
if (
|
||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||
and hook.__class__.__name__ == "WeightNorm"
|
||
):
|
||
torch.nn.utils.remove_weight_norm(self.flow)
|
||
if hasattr(self, "enc_q"):
|
||
for hook in self.enc_q._forward_pre_hooks.values():
|
||
if (
|
||
hook.__module__ == "torch.nn.utils.weight_norm"
|
||
and hook.__class__.__name__ == "WeightNorm"
|
||
):
|
||
torch.nn.utils.remove_weight_norm(self.enc_q)
|
||
return self
|
||
|
||
@torch.jit.ignore
|
||
def forward(
|
||
self,
|
||
phone: torch.Tensor,
|
||
phone_lengths: torch.Tensor,
|
||
pitch: torch.Tensor,
|
||
pitchf: torch.Tensor,
|
||
y: torch.Tensor,
|
||
y_lengths: torch.Tensor,
|
||
ds: Optional[torch.Tensor] = None,
|
||
): # 这里ds是id,[bs,1]
|
||
# print(1,pitch.shape)#[bs,t]
|
||
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
|
||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
||
z_p = self.flow(z, y_mask, g=g)
|
||
z_slice, ids_slice = rand_slice_segments_on_last_dim(
|
||
z, y_lengths, self.segment_size
|
||
)
|
||
# print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
|
||
pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
|
||
# print(-2,pitchf.shape,z_slice.shape)
|
||
o = self.dec(z_slice, pitchf, g=g)
|
||
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||
|
||
@torch.jit.export
|
||
def infer(
|
||
self,
|
||
phone: torch.Tensor,
|
||
phone_lengths: torch.Tensor,
|
||
pitch: torch.Tensor,
|
||
nsff0: torch.Tensor,
|
||
sid: torch.Tensor,
|
||
skip_head: Optional[torch.Tensor] = None,
|
||
return_length: Optional[torch.Tensor] = None,
|
||
# return_length2: Optional[torch.Tensor] = None,
|
||
):
|
||
g = self.emb_g(sid).unsqueeze(-1)
|
||
if skip_head is not None and return_length is not None:
|
||
assert isinstance(skip_head, torch.Tensor)
|
||
assert isinstance(return_length, torch.Tensor)
|
||
head = int(skip_head.item())
|
||
length = int(return_length.item())
|
||
flow_head = torch.clamp(skip_head - 24, min=0)
|
||
dec_head = head - int(flow_head.item())
|
||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head)
|
||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||
z = z[:, :, dec_head : dec_head + length]
|
||
x_mask = x_mask[:, :, dec_head : dec_head + length]
|
||
nsff0 = nsff0[:, head : head + length]
|
||
else:
|
||
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
||
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||
o = self.dec(
|
||
z * x_mask,
|
||
nsff0,
|
||
g=g,
|
||
# n_res=return_length2,
|
||
)
|
||
return o, x_mask, (z, z_p, m_p, logs_p)
|
||
|
||
|
||
class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` with ``text_encoder_in_channels`` fixed to 256."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: int,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: str | int,
    ):
        """Forward every argument unchanged to the parent and append 256.

        All parameters have the meaning they have on
        ``SynthesizerTrnMsNSFsid.__init__``; the only thing added here is the
        constant passed as ``text_encoder_in_channels``.
        """
        super().__init__(
            spec_channels=spec_channels,
            segment_size=segment_size,
            inter_channels=inter_channels,
            hidden_channels=hidden_channels,
            filter_channels=filter_channels,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            resblock=resblock,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            upsample_rates=upsample_rates,
            upsample_initial_channel=upsample_initial_channel,
            upsample_kernel_sizes=upsample_kernel_sizes,
            spk_embed_dim=spk_embed_dim,
            gin_channels=gin_channels,
            sr=sr,
            text_encoder_in_channels=256,
        )
|
||
|
||
|
||
class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
    """``SynthesizerTrnMsNSFsid`` with ``text_encoder_in_channels`` fixed to 768."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: int,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr: str | int,
    ):
        """Delegate to the parent constructor, supplying 768 as the last argument.

        The eighteen caller-provided values are handed over untouched; see
        ``SynthesizerTrnMsNSFsid.__init__`` for what each one configures.
        """
        # Width handed to the parent as ``text_encoder_in_channels``.
        encoder_in_channels = 768
        super().__init__(
            spec_channels=spec_channels,
            segment_size=segment_size,
            inter_channels=inter_channels,
            hidden_channels=hidden_channels,
            filter_channels=filter_channels,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            resblock=resblock,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            upsample_rates=upsample_rates,
            upsample_initial_channel=upsample_initial_channel,
            upsample_kernel_sizes=upsample_kernel_sizes,
            spk_embed_dim=spk_embed_dim,
            gin_channels=gin_channels,
            sr=sr,
            text_encoder_in_channels=encoder_in_channels,
        )
|
||
|
||
|
||
class SynthesizerTrnMs256NSFsid_nono(nn.Module):
    """Synthesizer variant without pitch input, decoding with ``Generator``.

    Sub-modules created in ``__init__``:

    * ``enc_p``  - ``TextEncoder`` built with 256 input channels and ``f0=False``.
    * ``dec``    - ``Generator`` turning latents into the output.
    * ``enc_q``  - ``PosteriorEncoder`` used by ``forward`` only.
    * ``flow``   - ``ResidualCouplingBlock`` (run forward in ``forward`` and
      with ``reverse=True`` in ``infer``).
    * ``emb_g``  - ``nn.Embedding`` mapping an id to a ``gin_channels`` vector.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: int,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr=None,
    ):
        """Store the hyper-parameters and build the sub-modules.

        Args:
            spec_channels: First argument handed to ``PosteriorEncoder``.
            segment_size: Slice length used by ``forward`` when cutting ``z``.
            inter_channels: Channel count shared by ``enc_p``, ``dec``,
                ``enc_q`` and ``flow``.
            hidden_channels: Forwarded to ``enc_p``, ``enc_q`` and ``flow``.
            filter_channels: Forwarded to ``enc_p``.
            n_heads: Forwarded to ``enc_p``.
            n_layers: Forwarded to ``enc_p``.
            kernel_size: Forwarded to ``enc_p``.
            p_dropout: Converted with ``float()`` before being stored and
                before being forwarded to ``enc_p``.
            resblock: Forwarded to ``dec``.
            resblock_kernel_sizes: Forwarded to ``dec``.
            resblock_dilation_sizes: Forwarded to ``dec``.
            upsample_rates: Forwarded to ``dec``.
            upsample_initial_channel: Forwarded to ``dec``.
            upsample_kernel_sizes: Forwarded to ``dec``.
            spk_embed_dim: Number of rows of the ``emb_g`` embedding table.
            gin_channels: Width of the ``emb_g`` embedding; also forwarded to
                ``dec``, ``enc_q`` and ``flow``.
            sr: Accepted but not used anywhere in this constructor.
        """
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder(
            256,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            float(p_dropout),
            f0=False,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        """Call ``remove_weight_norm()`` on ``dec``, ``flow`` and, if present, ``enc_q``."""
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        if hasattr(self, "enc_q"):
            self.enc_q.remove_weight_norm()

    def __prepare_scriptable__(self):
        """Remove weight-norm pre-hooks registered directly on the sub-modules.

        Only hooks attached to ``dec``, ``flow`` and (if present) ``enc_q``
        themselves are inspected; nested children are not walked here.

        Returns:
            ``self``.
        """
        targets = [self.dec, self.flow]
        if hasattr(self, "enc_q"):
            targets.append(self.enc_q)
        for module in targets:
            # Iterate over a snapshot: remove_weight_norm deletes entries from
            # the very dict being walked, which otherwise raises RuntimeError
            # as soon as more than one hook is registered.
            for hook in list(module._forward_pre_hooks.values()):
                # The hook we want to remove is an instance of the WeightNorm
                # class, so normally we would use `isinstance(...)`, but that
                # class is not accessible because of shadowing, so we check the
                # module and class names directly.
                # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
                if (
                    hook.__module__ == "torch.nn.utils.weight_norm"
                    and hook.__class__.__name__ == "WeightNorm"
                ):
                    # Pass the hook's own parameter name so hooks on parameters
                    # other than "weight" are removed too instead of failing.
                    torch.nn.utils.remove_weight_norm(module, name=hook.name)
        return self

    @torch.jit.ignore
    def forward(self, phone, phone_lengths, y, y_lengths, ds):  # ds is the id, [bs, 1]
        """Run prior encoder, posterior encoder, flow and decoder on a random slice.

        Excluded from TorchScript via ``torch.jit.ignore``.

        Args:
            phone: Input passed to ``enc_p`` (the pitch slot receives ``None``).
            phone_lengths: Lengths passed to ``enc_p``.
            y: Input to the posterior encoder ``enc_q``.
            y_lengths: Lengths passed to ``enc_q`` and to the random slicer.
            ds: Index tensor looked up in ``emb_g``.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``.
        """
        g = self.emb_g(ds).unsqueeze(-1)  # [b, 256, 1]; the trailing 1 is t, broadcast
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        z_slice, ids_slice = rand_slice_segments_on_last_dim(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        sid: torch.Tensor,
        skip_head: Optional[torch.Tensor] = None,
        return_length: Optional[torch.Tensor] = None,
    ):
        """Sample from the prior, invert the flow and decode.

        When both ``skip_head`` and ``return_length`` are given, the prior
        encoder is called with ``flow_head = clamp(skip_head - 24, min=0)`` as
        an extra argument, and ``z`` and ``x_mask`` are cropped to
        ``return_length`` frames before decoding. Otherwise everything is
        decoded uncropped.

        Args:
            phone: Input passed to ``enc_p`` (the pitch slot receives ``None``).
            phone_lengths: Lengths passed to ``enc_p``.
            sid: Index tensor looked up in ``emb_g``.
            skip_head: Optional single-element tensor; start offset of the crop.
            return_length: Optional single-element tensor; length of the crop.

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))``.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        if skip_head is not None and return_length is not None:
            assert isinstance(skip_head, torch.Tensor)
            assert isinstance(return_length, torch.Tensor)
            head = int(skip_head.item())
            length = int(return_length.item())
            flow_head = torch.clamp(skip_head - 24, min=0)
            # Offset of the requested start inside the window that begins at flow_head.
            dec_head = head - int(flow_head.item())
            m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths, flow_head)
            # Sample around the prior mean; 0.66666 scales the random noise.
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            z = z[:, :, dec_head : dec_head + length]
            x_mask = x_mask[:, :, dec_head : dec_head + length]
        else:
            m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
            z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
            z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(
            z * x_mask,
            g=g,
        )
        return o, x_mask, (z, z_p, m_p, logs_p)
|
||
|
||
|
||
class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMs256NSFsid_nono):
    """``SynthesizerTrnMs256NSFsid_nono`` whose ``enc_p`` is rebuilt with 768 input channels."""

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: int,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        spk_embed_dim: int,
        gin_channels: int,
        sr=None,
    ):
        """Build the parent model, then swap its text encoder for a 768-channel one.

        ``sr`` is accepted but neither used here nor handed to the parent; all
        other arguments are passed through to the parent constructor as-is.
        """
        super().__init__(
            spec_channels=spec_channels,
            segment_size=segment_size,
            inter_channels=inter_channels,
            hidden_channels=hidden_channels,
            filter_channels=filter_channels,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            resblock=resblock,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            upsample_rates=upsample_rates,
            upsample_initial_channel=upsample_initial_channel,
            upsample_kernel_sizes=upsample_kernel_sizes,
            spk_embed_dim=spk_embed_dim,
            gin_channels=gin_channels,
        )
        # Same positional arguments the parent used for its encoder, except
        # for the leading input width.
        encoder_args = (
            768,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            float(p_dropout),
        )
        # NOTE(review): the explicit delete before re-assignment presumably
        # changes where enc_p sits among the registered sub-modules compared
        # with a plain overwrite - keep it unless that ordering is confirmed
        # irrelevant.
        del self.enc_p
        self.enc_p = TextEncoder(*encoder_args, f0=False)
|
||
|
||
|
||
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of one ``DiscriminatorS`` followed by six ``DiscriminatorP`` modules.

    The period discriminators use periods 2, 3, 5, 7, 11 and 17.
    """

    def __init__(self, use_spectral_norm=False):
        """Create the sub-discriminators, passing ``use_spectral_norm`` to each."""
        super().__init__()
        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        members.extend(
            DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
            for period in (2, 3, 5, 7, 11, 17)
        )
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and then on ``y_hat``.

        Returns:
            A 4-tuple of lists, each with one entry per sub-discriminator:
            outputs for ``y``, outputs for ``y_hat``, feature maps for ``y``
            and feature maps for ``y_hat``.
        """
        rows = []
        for disc in self.discriminators:
            real_score, real_fmap = disc(y)
            fake_score, fake_fmap = disc(y_hat)
            rows.append((real_score, fake_score, real_fmap, fake_fmap))

        if not rows:
            # zip(*[]) yields nothing, so build the empty result by hand.
            return [], [], [], []

        # Transpose the per-discriminator rows into four per-kind lists.
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = (list(column) for column in zip(*rows))
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||
|
||
|
||
class MultiPeriodDiscriminatorV2(torch.nn.Module):
    """Bundle of one ``DiscriminatorS`` followed by eight ``DiscriminatorP`` modules.

    The period discriminators use periods 2, 3, 5, 7, 11, 17, 23 and 37.
    """

    def __init__(self, use_spectral_norm=False):
        """Create the sub-discriminators, passing ``use_spectral_norm`` to each."""
        super().__init__()
        period_values = (2, 3, 5, 7, 11, 17, 23, 37)
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=use_spectral_norm),
                *(
                    DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
                    for period in period_values
                ),
            ]
        )

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and then on ``y_hat``.

        Returns:
            A 4-tuple of lists, each with one entry per sub-discriminator:
            outputs for ``y``, outputs for ``y_hat``, feature maps for ``y``
            and feature maps for ``y_hat``.
        """
        # Buckets, in return order: real outputs, generated outputs,
        # real feature maps, generated feature maps.
        buckets = ([], [], [], [])
        for disc in self.discriminators:
            out_real, maps_real = disc(y)
            out_fake, maps_fake = disc(y_hat)
            for bucket, value in zip(
                buckets, (out_real, out_fake, maps_real, maps_fake)
            ):
                bucket.append(value)
        return buckets
|
||
|
||
|
||
class DiscriminatorS(torch.nn.Module):
    """Stack of six normalised ``Conv1d`` layers followed by a one-channel ``Conv1d``."""

    def __init__(self, use_spectral_norm=False):
        """Build the convolution stack.

        Args:
            use_spectral_norm: When this compares equal to ``False`` every
                layer is wrapped in ``weight_norm``; otherwise in
                ``spectral_norm``.
        """
        super().__init__()
        if use_spectral_norm == False:  # noqa: E712 - equality test kept on purpose
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        layer_specs = (
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        )
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(c_in, c_out, kernel, step, groups=groups, padding=pad))
                for c_in, c_out, kernel, step, groups, pad in layer_specs
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Apply the stack and collect every intermediate activation.

        Returns:
            ``(flat, fmap)`` where ``flat`` is the ``conv_post`` output
            flattened from dim 1 onwards and ``fmap`` lists the activation
            after each layer in ``convs`` (post leaky-ReLU) plus the raw
            ``conv_post`` output.
        """
        feature_maps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), residuals.LRELU_SLOPE)
            feature_maps.append(x)

        x = self.conv_post(x)
        feature_maps.append(x)
        return torch.flatten(x, 1, -1), feature_maps
|
||
|
||
|
||
class DiscriminatorP(torch.nn.Module):
    """Discriminator that folds a 1-D input into 2-D by ``period`` and applies ``Conv2d`` layers."""

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        """Build five normalised ``Conv2d`` layers plus a one-channel output layer.

        Args:
            period: Width of the last dimension after the 1-D input is folded.
            kernel_size: Kernel height of the five stacked layers (width is 1).
            stride: Vertical stride of the first four layers; the fifth uses 1.
            use_spectral_norm: When this compares equal to ``False`` every
                layer is wrapped in ``weight_norm``; otherwise in
                ``spectral_norm``.
        """
        super().__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        if use_spectral_norm == False:  # noqa: E712 - equality test kept on purpose
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # (in_channels, out_channels, stride) for each stacked layer.
        layer_plan = (
            (1, 32, (stride, 1)),
            (32, 128, (stride, 1)),
            (128, 512, (stride, 1)),
            (512, 1024, (stride, 1)),
            (1024, 1024, 1),
        )
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        c_in,
                        c_out,
                        (kernel_size, 1),
                        layer_stride,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                )
                for c_in, c_out, layer_stride in layer_plan
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Fold ``x`` by ``period``, apply the stack and collect activations.

        ``x`` is unpacked as a 3-D ``(b, c, t)`` tensor. If ``t`` is not a
        multiple of ``period`` it is reflect-padded on the right first.

        Returns:
            ``(flat, fmap)`` where ``flat`` is the ``conv_post`` output
            flattened from dim 1 onwards and ``fmap`` lists the activation
            after each layer in ``convs`` (post leaky-ReLU) plus the raw
            ``conv_post`` output.
        """
        feature_maps = []

        # 1d to 2d
        batch, channels, length = x.shape
        remainder = length % self.period
        if remainder != 0:  # pad first
            n_pad = self.period - remainder
            if has_xpu and x.dtype == torch.bfloat16:
                # On XPU, bfloat16 input is padded in float16 and cast back.
                as_half = x.to(dtype=torch.float16)
                x = F.pad(as_half, (0, n_pad), "reflect").to(dtype=torch.bfloat16)
            else:
                x = F.pad(x, (0, n_pad), "reflect")
            length = length + n_pad
        x = x.view(batch, channels, length // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), residuals.LRELU_SLOPE)
            feature_maps.append(x)

        x = self.conv_post(x)
        feature_maps.append(x)
        return torch.flatten(x, 1, -1), feature_maps
|