# gzhu-biyesheji/dataset_generator.py
# Builds a question/answer fine-tuning dataset from markdown documentation
# by prompting an LLM over each leaf section of the parsed documents.
import os
import json
from tools.parse_markdown import parse_markdown, MarkdownNode
from tools.openai_api import generate_json_via_llm
from prompt.base import create_dataset
from config.llm import load_config
from tqdm import tqdm
def process_markdown_file(file_path):
    """Parse one markdown file and return the text of every leaf section.

    Each entry is the heading breadcrumb joined with " -> ", followed by a
    newline and the leaf node's content. Leaves with empty content are skipped.
    """
    with open(file_path, 'r', encoding='utf-8') as fh:
        tree = parse_markdown(fh.read())

    sections = []

    def walk(node, breadcrumb):
        # Extend the breadcrumb with this node's title (new list, no mutation).
        trail = breadcrumb + [node.title]
        if node.children:
            for child in node.children:
                walk(child, trail)
        elif node.content:  # leaf node with actual text
            sections.append(' -> '.join(trail) + '\n' + node.content)

    walk(tree, [])
    return sections
def find_markdown_files(directory):
    """Recursively collect the paths of all ``.md`` files under *directory*."""
    return [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(directory)
        for filename in filenames
        if filename.endswith('.md')
    ]
def process_all_markdown(doc_dir):
    """Parse every markdown file under *doc_dir* and flatten all leaf sections
    into a single list, in file-discovery order."""
    return [
        section
        for path in find_markdown_files(doc_dir)
        for section in process_markdown_file(path)
    ]
def save_dataset(dataset, output_dir):
    """Write *dataset* as pretty-printed UTF-8 JSON to <output_dir>/dataset.json.

    Creates *output_dir* if it does not already exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    target = os.path.join(output_dir, 'dataset.json')
    with open(target, 'w', encoding='utf-8') as out:
        out.write(json.dumps(dataset, ensure_ascii=False, indent=2))
if __name__ == "__main__":
    # Parse all markdown documents into leaf-section text chunks.
    results = process_all_markdown('workdir/my_docs')
    # Load the LLM endpoint/credentials configuration.
    config = load_config()

    SAMPLES_PER_CHUNK = 3  # generation attempts per document chunk

    dataset = []
    # tqdm wraps the outer loop to show generation progress.
    for content in tqdm(results, desc="生成数据集进度", unit="文档"):
        for _ in range(SAMPLES_PER_CHUNK):
            prompt = create_dataset.create(
                "LLaMA-Factory",  # project name
                content,          # document chunk content
                """{
    "dataset":[
        {
            "question":"",
            "answer":""
        }
    ]
}"""
            )
            # Call the LLM; on any failure, log and move on (best-effort —
            # a single bad generation should not abort the whole run).
            try:
                result = generate_json_via_llm(
                    prompt=prompt,
                    base_url=config["openai"]["base_url"],
                    api_key=config["openai"]["api_key"],
                    model_id=config["openai"]["model_id"]
                )
                # Parse the response once (original parsed it twice).
                entries = json.loads(result)["dataset"]
                print(entries)
                dataset.extend(entries)
            except Exception as e:
                print(f"生成数据集时出错: {e}")

    # Persist the accumulated Q/A pairs.
    save_dataset(dataset, 'workdir/dataset2')
    print(f"数据集已生成,共{len(dataset)}条数据")