1
0
mirror of https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git synced 2026-06-05 09:10:25 +08:00

fix(train): matplotlib deprecation (#103)

This commit is contained in:
源文雨
2025-01-01 00:23:16 +09:00
parent 89f7fa25cc
commit fe11be3c94

View File

@@ -5,7 +5,6 @@ import logging
import os import os
import sys import sys
from copy import deepcopy from copy import deepcopy
import math
import codecs import codecs
import numpy as np import numpy as np
@@ -17,6 +16,13 @@ MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging logger = logging
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
import matplotlib.pylab as plt
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1): def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path) assert os.path.isfile(checkpoint_path)
@@ -125,8 +131,6 @@ def plot_spectrogram_to_numpy(spectrogram):
MATPLOTLIB_FLAG = True MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib") mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING) mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2)) fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
@@ -136,8 +140,12 @@ def plot_spectrogram_to_numpy(spectrogram):
plt.tight_layout() plt.tight_layout()
fig.canvas.draw() fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") try:
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道RGB
except:
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close() plt.close()
return data return data
@@ -151,8 +159,6 @@ def plot_alignment_to_numpy(alignment, info=None):
MATPLOTLIB_FLAG = True MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger("matplotlib") mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING) mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4)) fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow( im = ax.imshow(
@@ -167,8 +173,12 @@ def plot_alignment_to_numpy(alignment, info=None):
plt.tight_layout() plt.tight_layout()
fig.canvas.draw() fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") try:
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道RGB
except:
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close() plt.close()
return data return data