-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathgenerate_schema.py
53 lines (41 loc) · 1.78 KB
/
generate_schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import json
from dataclasses import fields
from typing import Any, Optional, Type
from pydantic import BaseModel, Field, create_model
from pydantic.json_schema import model_json_schema
from promptim.trainer import Config
def get_schema(cls: Type[Any]) -> dict:
    """Create a JSON schema dict from a dataclass or Pydantic model.

    Pydantic models are rendered directly with ``model_json_schema``.
    Dataclasses are first converted to an equivalent Pydantic model with
    ``create_model``:

    * a field with a plain default keeps that value as its schema default;
    * a field with a ``default_factory`` forwards the factory to Pydantic
      (the factory is not called here);
    * fields with any kind of default are typed ``Optional[...]`` so the
      schema also accepts ``null`` for them;
    * a field with no default is required and keeps its declared type;
    * ``metadata["description"]``, when present, becomes the field's
      schema description.

    Args:
        cls: A dataclass or Pydantic model *type* (the class itself, not an
            instance).

    Returns:
        A dict representing the JSON schema of the input class.

    Raises:
        TypeError: If the input is not a dataclass or Pydantic model type.
    """
    from dataclasses import MISSING

    # Reject non-classes up front. Dataclass instances also carry
    # ``__dataclass_fields__`` and would otherwise crash on ``cls.__name__``
    # with AttributeError instead of the documented TypeError.
    if not isinstance(cls, type):
        raise TypeError("Input must be a dataclass or Pydantic model")
    if issubclass(cls, BaseModel):
        return model_json_schema(cls)
    if not hasattr(cls, "__dataclass_fields__"):
        raise TypeError("Input must be a dataclass or Pydantic model")

    # Convert dataclass to Pydantic model
    fields_dict = {}
    for dc_field in fields(cls):
        field_kwargs = {}
        description = dc_field.metadata.get("description")
        if description:
            field_kwargs["description"] = description

        if dc_field.default is not MISSING:
            # Plain default value: recorded as "default" in the schema.
            fields_dict[dc_field.name] = (
                Optional[dc_field.type],
                Field(default=dc_field.default, **field_kwargs),
            )
        elif dc_field.default_factory is not MISSING:
            # Forward the factory itself. Using ``dc_field.default`` here
            # would leak the ``dataclasses.MISSING`` sentinel into the model.
            fields_dict[dc_field.name] = (
                Optional[dc_field.type],
                Field(default_factory=dc_field.default_factory, **field_kwargs),
            )
        else:
            # No default: required, and not made nullable. ``Field()`` with
            # no default marks the field as required, like ``...``.
            fields_dict[dc_field.name] = (dc_field.type, Field(**field_kwargs))

    pydantic_model = create_model(cls.__name__, **fields_dict)
    return model_json_schema(pydantic_model)
# Generate the JSON schema for promptim's trainer ``Config`` and write it to
# ``config-schema.json`` in the current working directory.
#
# ``get_schema`` renders through Pydantic v2's ``model_json_schema``, which
# emits JSON Schema Draft 2020-12 (e.g. ``$defs`` references). The document
# is therefore labelled with that dialect rather than draft-07. ``$schema``
# is placed first so the dialect is visible at the top of the file.
_generated_schema = get_schema(Config)
_generated_schema.pop("$schema", None)  # the explicit label below always wins
config_schema = {
    "$schema": "https://json-schema.org/draft/2020-12/schema",
    **_generated_schema,
}
with open("config-schema.json", "w", encoding="utf-8") as f:
    json.dump(config_schema, f, indent=2)
    f.write("\n")  # end the file with a newline, as text tools expect
print("Schema generated and saved to config-schema.json")