fix(tools): 修复了optional字段无法被解析的问题
This commit is contained in:
parent
86bcf90c66
commit
e16882953d
@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel, create_model
|
||||
from typing import Any, Dict, List, Optional, get_args, get_origin
|
||||
from typing import Any, Dict, List, Optional, Union, get_args, get_origin
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
|
||||
@ -13,20 +13,15 @@ def generate_example_json(model: type[BaseModel]) -> str:
|
||||
args = get_args(field_type)
|
||||
|
||||
if origin is list or origin is List:
|
||||
if args:
|
||||
return [_generate_example(args[0])]
|
||||
else:
|
||||
return []
|
||||
return [_generate_example(args[0])] if args else []
|
||||
elif origin is dict or origin is Dict:
|
||||
if len(args) == 2 and args[0] is str:
|
||||
if len(args) == 2:
|
||||
return {"key": _generate_example(args[1])}
|
||||
else:
|
||||
return {}
|
||||
elif origin is Optional or origin is type(None):
|
||||
if args:
|
||||
return _generate_example(args[0])
|
||||
else:
|
||||
return None
|
||||
return {}
|
||||
elif origin is Union:
|
||||
# 处理 Optional 类型(Union[T, None])
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
return _generate_example(non_none_args[0]) if non_none_args else None
|
||||
elif field_type is str:
|
||||
return "string"
|
||||
elif field_type is int:
|
||||
@ -39,10 +34,18 @@ def generate_example_json(model: type[BaseModel]) -> str:
|
||||
return datetime.now().isoformat()
|
||||
elif field_type is date:
|
||||
return date.today().isoformat()
|
||||
elif issubclass(field_type, BaseModel):
|
||||
return generate_example_json(field_type)
|
||||
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return json.loads(generate_example_json(field_type))
|
||||
else:
|
||||
return "unknown" # 对于未知类型返回 "unknown"
|
||||
# 处理直接类型注解(非泛型)
|
||||
if field_type is type(None):
|
||||
return None
|
||||
try:
|
||||
if issubclass(field_type, BaseModel):
|
||||
return json.loads(generate_example_json(field_type))
|
||||
except TypeError:
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
example_data = {}
|
||||
for field_name, field in model.model_fields.items():
|
||||
@ -55,9 +58,7 @@ if __name__ == "__main__":
|
||||
from pathlib import Path
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import Q_A
|
||||
class Q_A_list(BaseModel):
|
||||
Q_As: List[Q_A]
|
||||
from schema import dataset
|
||||
|
||||
print("示例 JSON:")
|
||||
print(generate_example_json(Q_A_list))
|
||||
print(generate_example_json(dataset))
|
||||
|
Loading…
x
Reference in New Issue
Block a user