mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 09:10:25 +08:00
chore(format): run black on dev (#9)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
96604e8175
commit
44725ddd2c
@@ -173,7 +173,10 @@ class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# Concat columns of pad to shift from relative to absolute indexing.
|
||||
x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0], )
|
||||
x = F.pad(
|
||||
x,
|
||||
[0, 1, 0, 0, 0, 0, 0, 0],
|
||||
)
|
||||
|
||||
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
||||
x_flat = x.view([batch, heads, length * 2 * length])
|
||||
|
||||
@@ -57,7 +57,7 @@ class Encoder(nn.Module):
|
||||
)
|
||||
)
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
|
||||
|
||||
def __call__(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
|
||||
return super().__call__(x, x_mask)
|
||||
|
||||
@@ -146,7 +146,8 @@ class TextEncoder(nn.Module):
|
||||
x = self.lrelu(x)
|
||||
x = torch.transpose(x, 1, -1) # [b, h, t]
|
||||
x_mask = torch.unsqueeze(
|
||||
sequence_mask(lengths, x.size(2)), 1,
|
||||
sequence_mask(lengths, x.size(2)),
|
||||
1,
|
||||
).to(x.dtype)
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
"""
|
||||
|
||||
@@ -66,13 +66,15 @@ def sequence_mask(
|
||||
|
||||
|
||||
def total_grad_norm(
|
||||
parameters: Iterator[torch.nn.Parameter], norm_type: float=2.0,
|
||||
parameters: Iterator[torch.nn.Parameter],
|
||||
norm_type: float = 2.0,
|
||||
) -> float:
|
||||
norm_type = float(norm_type)
|
||||
total_norm = 0.0
|
||||
|
||||
for p in parameters:
|
||||
if p.grad is None: continue
|
||||
if p.grad is None:
|
||||
continue
|
||||
param_norm = p.grad.data.norm(norm_type)
|
||||
total_norm += float(param_norm.item()) ** norm_type
|
||||
total_norm = total_norm ** (1.0 / norm_type)
|
||||
|
||||
Reference in New Issue
Block a user