import os
|
|
import json
|
|
from tools.parse_markdown import parse_markdown, MarkdownNode
|
|
from tools.openai_api import generate_json_via_llm
|
|
from prompt.base import create_dataset
|
|
from config.llm import load_config
|
|
from tqdm import tqdm
|
|
|
|
def process_markdown_file(file_path):
    """Split one markdown file into per-leaf-section text chunks.

    Each chunk is the heading path from the root to a leaf node joined
    by ' -> ', followed by a newline and that leaf's content.  Leaves
    whose content is empty are skipped.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        markdown_text = handle.read()

    tree = parse_markdown(markdown_text)
    chunks = []

    # Iterative depth-first walk; each stack entry carries a node plus the
    # titles of its ancestors.  Children are pushed in reverse so chunks
    # come out in document order, exactly like a recursive traversal.
    stack = [(tree, [])]
    while stack:
        node, ancestor_titles = stack.pop()
        titles = ancestor_titles + [node.title]
        if node.children:
            for child in reversed(node.children):
                stack.append((child, titles))
        elif node.content:  # leaf node with text worth keeping
            chunks.append(' -> '.join(titles) + '\n' + node.content)

    return chunks
|
|
|
|
def find_markdown_files(directory):
    """Recursively collect the paths of all ``.md`` files under *directory*."""
    return [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(directory)
        for filename in filenames
        if filename.endswith('.md')
    ]
|
|
|
|
def process_all_markdown(doc_dir):
    """Flatten every markdown file under *doc_dir* into one list of chunks.

    Files are processed in the order ``find_markdown_files`` returns them,
    and each file's chunks keep their in-file order.
    """
    return [
        chunk
        for path in find_markdown_files(doc_dir)
        for chunk in process_markdown_file(path)
    ]
|
|
|
|
def save_dataset(dataset, output_dir):
    """Write *dataset* as pretty-printed UTF-8 JSON to ``<output_dir>/dataset.json``.

    The output directory is created if missing; an existing dataset.json
    is overwritten.
    """
    os.makedirs(output_dir, exist_ok=True)
    target = os.path.join(output_dir, 'dataset.json')
    # ensure_ascii=False keeps CJK text readable in the saved file.
    with open(target, 'w', encoding='utf-8') as out:
        out.write(json.dumps(dataset, ensure_ascii=False, indent=2))
|
|
|
|
if __name__ == "__main__":
    # Flatten every markdown doc under workdir/my_docs into text chunks.
    results = process_all_markdown('workdir/my_docs')

    # Load the LLM endpoint configuration (base_url / api_key / model_id).
    config = load_config()

    dataset = []
    # tqdm wraps the outer loop to show per-document generation progress.
    for content in tqdm(results, desc="生成数据集进度", unit="文档"):
        # Query the model three times per chunk to collect more QA variety.
        for _ in range(3):
            prompt = create_dataset.create(
                "LLaMA-Factory",  # project name
                content,  # document content
                # Expected JSON schema of the model's reply.
                """{
"dataset":[
{
"question":"",
"answer":""
}
]
}"""
            )

            # One bad response must not abort the whole run: report the
            # error and keep generating from the remaining chunks.
            try:
                result = generate_json_via_llm(
                    prompt=prompt,
                    base_url=config["openai"]["base_url"],
                    api_key=config["openai"]["api_key"],
                    model_id=config["openai"]["model_id"]
                )
                # Parse the reply once and reuse it (the original called
                # json.loads on the same string twice).
                qa_pairs = json.loads(result)["dataset"]
                print(qa_pairs)
                dataset.extend(qa_pairs)
            except Exception as e:
                print(f"生成数据集时出错: {e}")

    # Persist everything that was successfully generated.
    save_dataset(dataset, 'workdir/dataset2')
    print(f"数据集已生成,共{len(dataset)}条数据")