mirror of
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-06-05 01:10:22 +08:00
fix(train): matplotlib deprecation (#103)
This commit is contained in:
@@ -5,7 +5,6 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
import math
|
||||
|
||||
import codecs
|
||||
import numpy as np
|
||||
@@ -17,6 +16,13 @@ MATPLOTLIB_FLAG = False
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
logger = logging
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
logging.getLogger("markdown_it").setLevel(logging.WARNING)
|
||||
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
||||
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
@@ -125,8 +131,6 @@ def plot_spectrogram_to_numpy(spectrogram):
|
||||
MATPLOTLIB_FLAG = True
|
||||
mpl_logger = logging.getLogger("matplotlib")
|
||||
mpl_logger.setLevel(logging.WARNING)
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||
@@ -136,8 +140,12 @@ def plot_spectrogram_to_numpy(spectrogram):
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
try:
|
||||
data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道(RGB)
|
||||
except:
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
|
||||
@@ -151,8 +159,6 @@ def plot_alignment_to_numpy(alignment, info=None):
|
||||
MATPLOTLIB_FLAG = True
|
||||
mpl_logger = logging.getLogger("matplotlib")
|
||||
mpl_logger.setLevel(logging.WARNING)
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
im = ax.imshow(
|
||||
@@ -167,8 +173,12 @@ def plot_alignment_to_numpy(alignment, info=None):
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
try:
|
||||
data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道(RGB)
|
||||
except:
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user