gzhu-biyesheji/paper/object/plot_training_data.py
carry b4da95e81e style(paper/object): 调整图表尺寸以提高可读性
将图表尺寸从 (12, 9) 调整为 (12, 12),以便更好地展示训练数据的细节
2025-04-29 19:00:38 +08:00

44 lines
1.2 KiB
Python

import pandas as pd
import matplotlib.pyplot as plt
# 设置全局字体大小
plt.rcParams.update({
'font.size': 16, # 全局字体大小
'axes.titlesize': 20, # 标题字体大小
'axes.labelsize': 16, # 坐标轴标签字体大小
'xtick.labelsize': 14, # x轴刻度标签字体大小
'ytick.labelsize': 14, # y轴刻度标签字体大小
'legend.fontsize': 14, # 图例字体大小
})
# 读取CSV文件
data = pd.read_csv('training_data.csv')
# 创建图表
plt.figure(figsize=(12, 12))
# 绘制梯度范数变化曲线
plt.subplot(3, 1, 1) # 修改为 3行1列的第1个
plt.plot(data['Step'], data['grad_norm'], label='Gradient Norm')
plt.xlabel('Step')
plt.ylabel('Gradient Norm')
plt.legend()
# 绘制损失值变化曲线
plt.subplot(3, 1, 2) # 修改为 3行1列的第2个
plt.plot(data['Step'], data['loss'], label='Loss', color='orange')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.legend()
# 绘制学习率变化曲线
plt.subplot(3, 1, 3) # 修改为 3行1列的第3个
plt.plot(data['Step'], data['learning_rate'], label='Learning Rate', color='green')
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.legend()
# 调整布局并保存图片
plt.tight_layout()
plt.savefig('training_metrics.png')
plt.show()