Compare commits

4 Commits: 9fb31c46c8 ... 148f4afb25

- 148f4afb25
- 11a3039775
- a4289815ba
- 088067d335
```diff
@@ -46,30 +46,10 @@ def train_page():
             dataset = get_datasets().get(Query().name == dataset_name)
             dataset = [ds["message"][0] for ds in dataset["dataset_items"]]

-            class LossCallback(TrainerCallback):
-                def __init__(self):
-                    self.loss_data = []
-                    self.log_text = ""
-                    self.last_output = {"text": "", "plot": None}
-
-                def on_log(self, args, state, control, logs=None, **kwargs):
-                    if "loss" in logs:
-                        self.loss_data.append({
-                            "step": state.global_step,
-                            "loss": float(logs["loss"])
-                        })
-                        self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
-                        # Add the following two print statements
-                        print(f"Current Step: {state.global_step}")
-                        print(f"Loss Value: {logs['loss']:.4f}")
-                        self.last_output = {
-                            "text": self.log_text,
-                        }
-                    # Do not return `control`, to avoid interfering with the training process
-
             train_model(get_model(), get_tokenizer(),
-                        dataset, get_workdir()+"/checkpoint",
+                        dataset, get_workdir() + "/training" + "/1",
                         learning_rate, per_device_train_batch_size, epoch,
-                        save_steps, lora_rank, LossCallback)
+                        save_steps, lora_rank)

             train_button.click(
@@ -80,7 +60,7 @@ def train_page():
                 per_device_train_batch_size_input,
                 epoch_input,
                 save_steps_input,
-                lora_rank_input  # newly added lora_rank_input
+                lora_rank_input
             ],
             outputs=output
         )
```
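The inline LossCallback removed above is not lost functionality: `train_model` now takes an optional `trainer_callback` (see the last hunk in this compare), so a page that still wants live loss output can define the callback elsewhere and pass it in. Below is a minimal sketch of such a callback, assuming the stock transformers `TrainerCallback` API; the class and attribute names are illustrative, not part of this diff.

```python
from transformers import TrainerCallback

class LossLoggerCallback(TrainerCallback):
    """Standalone variant of the removed inline LossCallback: collects per-step loss."""

    def __init__(self):
        self.loss_data = []   # [{"step": int, "loss": float}, ...]
        self.log_text = ""    # human-readable running log

    def on_log(self, args, state, control, logs=None, **kwargs):
        # `logs` holds whatever the Trainer just reported (loss, learning rate, ...).
        if logs and "loss" in logs:
            self.loss_data.append({"step": state.global_step,
                                   "loss": float(logs["loss"])})
            self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
        # Returning nothing leaves the default TrainerControl untouched.
```

`Trainer.add_callback` accepts either a callback class or an instance, so both `trainer_callback=LossLoggerCallback` and `trainer_callback=LossLoggerCallback()` would work with the guarded `add_callback` call added later in this compare.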
main.py (2 changes)

```diff
@@ -1,5 +1,5 @@
 import gradio as gr
-from frontend.setting_page import setting_page
+import unsloth
 from frontend import *
 from db import initialize_sqlite_db, initialize_prompt_store
 from global_var import init_global_var, get_sql_engine, get_prompt_store
```
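The swapped import is presumably about load order rather than new functionality: Unsloth applies its patches when it is imported and warns if transformers-based modules are loaded first, so main.py now imports it before the frontend package pulls in the training stack. A sketch of the resulting top of main.py, with the ordering rationale (an assumption, not stated in the diff) as comments:

```python
import gradio as gr
import unsloth  # replaces `from frontend.setting_page import setting_page`; imported early
                # so Unsloth's patches are in place before the frontend modules load
                # transformers/PEFT for training
from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store
from global_var import init_global_var, get_sql_engine, get_prompt_store
```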
```diff
@@ -7,3 +7,5 @@ tinydb>=4.0.0
 unsloth>=2025.3.19
 sqlmodel>=0.0.24
 jinja2>=3.1.0
+tensorboardX>=2.6.2.2
+tensorboard>=2.19.0
```
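The two new requirements back the `report_to="tensorboard"` change made in `train_model` below. Once a run has written event files, they can be viewed with the usual `tensorboard --logdir <dir>` command, or started from Python; a minimal sketch, assuming the log path that this diff builds from `train_dir`:

```python
from tensorboard import program

def launch_tensorboard(logdir: str, port: int = 6006) -> str:
    """Start a local TensorBoard server for the training logs and return its URL."""
    tb = program.TensorBoard()
    tb.configure(argv=[None, "--logdir", logdir, "--port", str(port)])
    return tb.launch()

# e.g. launch_tensorboard(get_workdir() + "/training/1/logs")
```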
```diff
@@ -35,13 +35,13 @@ def train_model(
     model,
     tokenizer,
     dataset: list,
-    output_dir: str,
+    train_dir: str,
     learning_rate: float,
     per_device_train_batch_size: int,
     epoch: int,
     save_steps: int,
     lora_rank: int,
-    trainer_callback
+    trainer_callback=None
 ) -> None:
     # Model configuration parameters
     dtype = None  # data type; None means auto-select
```
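Because `trainer_callback` now defaults to None, existing callers only need to change the directory argument; passing a callback becomes opt-in. A sketch of both call styles under the new signature (the paths mirror the train_page call above, and the callback name is the illustrative one from earlier):

```python
# Plain call: no callback, progress goes to TensorBoard via report_to="tensorboard".
train_model(get_model(), get_tokenizer(),
            dataset, get_workdir() + "/training" + "/1",
            learning_rate, per_device_train_batch_size, epoch,
            save_steps, lora_rank)

# Opt-in callback: pass a TrainerCallback class or instance as the final argument.
train_model(get_model(), get_tokenizer(),
            dataset, get_workdir() + "/training" + "/1",
            learning_rate, per_device_train_batch_size, epoch,
            save_steps, lora_rank, trainer_callback=LossLoggerCallback())
```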
```diff
@@ -113,14 +113,16 @@ def train_model(
             optim="adamw_8bit",     # 8-bit AdamW optimizer to save VRAM, with almost no impact on training quality
             weight_decay=0.01,      # weight-decay coefficient for regularization, prevents overfitting
             seed=114514,            # random seed
-            output_dir=output_dir,  # where model checkpoints and training logs are saved
+            output_dir=train_dir + "/checkpoints",  # where model checkpoints and training logs are saved
             save_strategy="steps",  # save intermediate weights every N steps
             save_steps=save_steps,  # save interval passed in dynamically
-            report_to="none",
+            logging_dir=train_dir + "/logs",        # log file storage path
+            report_to="tensorboard",                # log with TensorBoard
         ),
     )

-    trainer.add_callback(trainer_callback)
+    if trainer_callback is not None:
+        trainer.add_callback(trainer_callback)

     trainer = train_on_responses_only(
         trainer,
```
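Taken together, the rename from `output_dir` to `train_dir` splits one output location into two: checkpoints and TensorBoard logs now live side by side under the run directory. A sketch of how the paths line up for the call made from train_page (the run-directory layout itself is inferred from the arguments, not spelled out in the diff):

```python
# How train_model derives its paths from the single train_dir argument:
train_dir      = get_workdir() + "/training" + "/1"   # passed in by train_page()
checkpoint_dir = train_dir + "/checkpoints"           # TrainingArguments.output_dir
logging_dir    = train_dir + "/logs"                  # TrainingArguments.logging_dir
```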