import os
import json

from tqdm import tqdm

from tools.parse_markdown import parse_markdown, MarkdownNode
from tools.openai_api import generate_json_via_llm
from prompt.base import create_dataset
from config.llm import load_config


def process_markdown_file(file_path):
    """Parse one markdown file into per-leaf-section text chunks.

    Each chunk is the breadcrumb of heading titles from the root to a
    leaf node, joined by ' -> ', followed by a newline and the leaf's
    content. Leaf nodes without content are skipped.

    Args:
        file_path: Path to a markdown file (read as UTF-8).

    Returns:
        list[str]: One chunk per non-empty leaf heading.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    root = parse_markdown(content)
    results = []

    def traverse(node, parent_titles):
        # Carry the title path down without mutating the caller's list.
        current_titles = parent_titles + [node.title]
        if not node.children:
            # Leaf node: emit "a -> b -> c\n<content>" when content exists.
            if node.content:
                results.append(' -> '.join(current_titles) + '\n' + node.content)
        else:
            for child in node.children:
                traverse(child, current_titles)

    traverse(root, [])
    return results


def find_markdown_files(directory):
    """Recursively collect paths of all '.md' files under *directory*."""
    markdown_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith('.md'):
                markdown_files.append(os.path.join(root, file))
    return markdown_files


def process_all_markdown(doc_dir):
    """Parse every markdown file under *doc_dir* into a flat list of chunks."""
    all_results = []
    markdown_files = find_markdown_files(doc_dir)
    for file_path in markdown_files:
        results = process_markdown_file(file_path)
        all_results.extend(results)
    return all_results


def save_dataset(dataset, output_dir):
    """Write *dataset* as pretty-printed JSON to <output_dir>/dataset.json.

    Creates *output_dir* if it does not exist; ensure_ascii=False keeps
    non-ASCII (e.g. Chinese) text readable in the output file.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'dataset.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    # Parse the markdown documents into per-section chunks.
    results = process_all_markdown('workdir/my_docs')

    # Load the LLM configuration (base_url / api_key / model_id).
    config = load_config()

    dataset = []
    # Wrap the outer loop with tqdm to show a progress bar.
    for content in tqdm(results, desc="生成数据集进度", unit="文档"):
        # Query the LLM three times per chunk for broader Q/A coverage.
        for _ in range(3):
            prompt = create_dataset.create(
                "LLaMA-Factory",  # project name
                content,          # document content
                """{
    "dataset":[
        {
            "question":"",
            "answer":""
        }
    ]
}""",
            )
            # Call the LLM to generate JSON; keep going on failures so one
            # bad response does not abort the whole run.
            try:
                result = generate_json_via_llm(
                    prompt=prompt,
                    base_url=config["openai"]["base_url"],
                    api_key=config["openai"]["api_key"],
                    model_id=config["openai"]["model_id"],
                )
                # Parse once and reuse (the original called json.loads twice
                # on the same payload).
                entries = json.loads(result)["dataset"]
                print(entries)
                dataset.extend(entries)
            except Exception as e:
                print(f"生成数据集时出错: {e}")

    # Persist the accumulated dataset.
    save_dataset(dataset, 'workdir/dataset2')
    print(f"数据集已生成,共{len(dataset)}条数据")