1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 01:10:22 +08:00

fix(train): matplotlib deprecation (#103)

This commit is contained in:
源文雨
2025-01-01 00:23:16 +09:00
parent 89f7fa25cc
commit fe11be3c94

View File

@@ -5,7 +5,6 @@ import logging
import os
import sys
from copy import deepcopy
import math
import codecs
import numpy as np
@@ -17,6 +16,13 @@ MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
import matplotlib.pylab as plt
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
@@ -125,8 +131,6 @@ def plot_spectrogram_to_numpy(spectrogram):
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
@@ -136,8 +140,12 @@ def plot_spectrogram_to_numpy(spectrogram):
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
try:
data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道RGB
except:
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
@@ -151,8 +159,6 @@ def plot_alignment_to_numpy(alignment, info=None):
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(
@@ -167,8 +173,12 @@ def plot_alignment_to_numpy(alignment, info=None):
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
try:
data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道RGB
except:
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data