diff --git a/infer/lib/train/utils.py b/infer/lib/train/utils.py index 885668b..9f50ca7 100644 --- a/infer/lib/train/utils.py +++ b/infer/lib/train/utils.py @@ -142,7 +142,9 @@ def plot_spectrogram_to_numpy(spectrogram): fig.canvas.draw() try: data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8) - data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道(RGB) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[ + :, :, :3 + ] # 只取前三个通道(RGB) except: data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) @@ -175,7 +177,9 @@ def plot_alignment_to_numpy(alignment, info=None): fig.canvas.draw() try: data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8) - data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道(RGB) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[ + :, :, :3 + ] # 只取前三个通道(RGB) except: data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))