Compare commits

...

4 Commits

Author SHA1 Message Date
carry
148f4afb25 fix(main): 修复unsloth没有最先导入的问题,删除了重复的导入语句 2025-04-14 16:31:47 +08:00
carry
11a3039775 fix(train_page): 修正模型训练保存路径 2025-04-14 16:31:00 +08:00
carry
a4289815ba build: 添加 tensorboard 依赖
- 在 requirements.txt 中添加 tensorboard>=2.19.0
- 此改动增加了 tensorboard 作为项目的新依赖项
2025-04-14 16:30:39 +08:00
carry
088067d335 train: 更新模型训练功能和日志记录方式
- 修改训练目录结构,将检查点和日志分开保存
- 添加 TensorBoard 日志记录支持
- 移除自定义 LossCallback 类,简化训练流程
- 更新训练参数和回调机制,提高代码可读性
- 在 requirements.txt 中添加 tensorboardX 依赖
2025-04-14 16:19:37 +08:00
4 changed files with 15 additions and 31 deletions

View File

@@ -45,31 +45,11 @@ def train_page():
# 加载数据集
dataset = get_datasets().get(Query().name == dataset_name)
dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
class LossCallback(TrainerCallback):
def __init__(self):
self.loss_data = []
self.log_text = ""
self.last_output = {"text": "", "plot": None}
def on_log(self, args, state, control, logs=None, **kwargs):
if "loss" in logs:
self.loss_data.append({
"step": state.global_step,
"loss": float(logs["loss"])
})
self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
# 添加以下两行print语句
print(f"Current Step: {state.global_step}")
print(f"Loss Value: {logs['loss']:.4f}")
self.last_output = {
"text": self.log_text,
}
# 不返回 control避免干预训练过程
train_model(get_model(), get_tokenizer(),
dataset, get_workdir()+"/checkpoint",
dataset, get_workdir() + "/training" + "/1",
learning_rate, per_device_train_batch_size, epoch,
save_steps, lora_rank, LossCallback)
save_steps, lora_rank)
train_button.click(
@@ -80,7 +60,7 @@ def train_page():
per_device_train_batch_size_input,
epoch_input,
save_steps_input,
lora_rank_input # 新增lora_rank_input
lora_rank_input
],
outputs=output
)

View File

@@ -1,5 +1,5 @@
import gradio as gr
from frontend.setting_page import setting_page
import unsloth
from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store
from global_var import init_global_var, get_sql_engine, get_prompt_store

View File

@@ -6,4 +6,6 @@ langchain>=0.3
tinydb>=4.0.0
unsloth>=2025.3.19
sqlmodel>=0.0.24
jinja2>=3.1.0
jinja2>=3.1.0
tensorboardX>=2.6.2.2
tensorboard>=2.19.0

View File

@@ -35,13 +35,13 @@ def train_model(
model,
tokenizer,
dataset: list,
output_dir: str,
train_dir: str,
learning_rate: float,
per_device_train_batch_size: int,
epoch: int,
save_steps: int,
lora_rank: int,
trainer_callback
trainer_callback=None
) -> None:
# 模型配置参数
dtype = None # 数据类型None表示自动选择
@@ -113,14 +113,16 @@ def train_model(
optim="adamw_8bit", # 使用8位AdamW优化器节省显存几乎不影响训练效果
weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合
seed=114514, # 随机数种子
output_dir=output_dir, # 保存模型检查点和训练日志
output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志
save_strategy="steps", # 按步保存中间权重
save_steps=save_steps, # 使用动态传入的保存步数
report_to="none",
logging_dir=train_dir + "/logs", # 日志文件存储路径
report_to="tensorboard", # 使用TensorBoard记录日志
),
)
trainer.add_callback(trainer_callback)
if trainer_callback is not None:
trainer.add_callback(trainer_callback)
trainer = train_on_responses_only(
trainer,