From fe11be3c940abda2ca5cd2aafe240fe90257c67c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 1 Jan 2025 00:23:16 +0900 Subject: [PATCH] fix(train): matplotlib deprecation (#103) --- infer/lib/train/utils.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/infer/lib/train/utils.py b/infer/lib/train/utils.py index 761bce6..885668b 100644 --- a/infer/lib/train/utils.py +++ b/infer/lib/train/utils.py @@ -5,7 +5,6 @@ import logging import os import sys from copy import deepcopy -import math import codecs import numpy as np @@ -17,6 +16,13 @@ MATPLOTLIB_FLAG = False logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) logger = logging +logging.getLogger("numba").setLevel(logging.WARNING) +logging.getLogger("markdown_it").setLevel(logging.WARNING) +logging.getLogger("urllib3").setLevel(logging.WARNING) +logging.getLogger("matplotlib").setLevel(logging.WARNING) + +import matplotlib.pylab as plt + def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1): assert os.path.isfile(checkpoint_path) @@ -125,8 +131,6 @@ def plot_spectrogram_to_numpy(spectrogram): MATPLOTLIB_FLAG = True mpl_logger = logging.getLogger("matplotlib") mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np fig, ax = plt.subplots(figsize=(10, 2)) im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") @@ -136,8 +140,12 @@ def plot_spectrogram_to_numpy(spectrogram): plt.tight_layout() fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + try: + data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道(RGB) + except: + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plt.close() return data @@ -151,8 +159,6 @@ def plot_alignment_to_numpy(alignment, info=None): MATPLOTLIB_FLAG = True mpl_logger = logging.getLogger("matplotlib") mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np fig, ax = plt.subplots(figsize=(6, 4)) im = ax.imshow( @@ -167,8 +173,12 @@ def plot_alignment_to_numpy(alignment, info=None): plt.tight_layout() fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + try: + data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道(RGB) + except: + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plt.close() return data