Skip to content

Commit

Permalink
fix: add pydantic dictionary support (#2007)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonaslagoni authored Jul 27, 2024
1 parent 9f9fc78 commit 62fd08b
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/generators/python/PythonDependencyManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export class PythonDependencyManager extends AbstractDependencyManager {
const importMap: Record<string, string[]> = {};
const dependenciesToRender = [];
for (const dependency of individualDependencies) {
const regex = /from ([A-Za-z0-9]+) import ([A-Za-z0-9,\s]+)/g;
const regex = /from ([A-Za-z0-9]+) import ([A-Za-z0-9_\-,\s]+)/g;
const matches = regex.exec(dependency);

if (!matches) {
Expand Down
58 changes: 57 additions & 1 deletion src/generators/python/presets/Pydantic.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
ConstrainedDictionaryModel,
ConstrainedObjectPropertyModel,
ConstrainedUnionModel
} from '../../../models';
import { PythonOptions } from '../PythonGenerator';
Expand Down Expand Up @@ -64,7 +65,62 @@ const PYTHON_PYDANTIC_CLASS_PRESET: ClassPresetType<PythonOptions> = {
},
ctor: () => '',
getter: () => '',
setter: () => ''
setter: () => '',
additionalContent: ({ content, model, renderer }) => {
const allProperties = Object.keys(model.properties);
let dictionaryModel: ConstrainedObjectPropertyModel | undefined;
for (const property of Object.values(model.properties)) {
if (
property.property instanceof ConstrainedDictionaryModel &&
property.property.serializationType === 'unwrap'
) {
dictionaryModel = property;
}
}
const shouldHaveFunctions = dictionaryModel !== undefined;
if (!shouldHaveFunctions) {
return content;
}

renderer.dependencyManager.addDependency(
'from pydantic import model_serializer, model_validator'
);
// eslint-disable-next-line prettier/prettier
return `@model_serializer(mode='wrap')
def custom_serializer(self, handler):
serialized_self = handler(self)
${dictionaryModel?.propertyName} = getattr(self, "${dictionaryModel?.propertyName}")
if ${dictionaryModel?.propertyName} is not None:
for key, value in ${dictionaryModel?.propertyName}.items():
# Never overwrite existing values, to avoid clashes
if not hasattr(serialized_self, key):
serialized_self[key] = value
return serialized_self
@model_validator(mode='before')
@classmethod
def unwrap_${dictionaryModel?.propertyName}(cls, data):
json_properties = list(data.keys())
known_object_properties = [${allProperties
.map((value) => `'${value}'`)
.join(', ')}]
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
# Ignore attempts that validate regular models, only when unknown input is used we add unwrap extensions
if len(unknown_object_properties) == 0:
return data
known_json_properties = [${Object.values(model.properties)
.map((value) => `'${value.unconstrainedPropertyName}'`)
.join(', ')}]
${dictionaryModel?.propertyName} = {}
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
${dictionaryModel?.propertyName}[obj_key] = data.pop(obj_key, None)
data['${dictionaryModel?.unconstrainedPropertyName}'] = ${dictionaryModel?.propertyName}
return data
${content}`;
}
};

export const PYTHON_PYDANTIC_PRESET: PythonPreset<PythonOptions> = {
Expand Down
4 changes: 2 additions & 2 deletions test/generators/python/PythonDependencyManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ describe('PythonDependencyManager', () => {
test('should render unique dependency', () => {
const dependencyManager = new PythonDependencyManager(
PythonGenerator.defaultOptions,
['from x import y', 'from x import y2']
['from x import y', 'from x import y2', 'from x import y_2']
);
expect(dependencyManager.renderDependencies()).toEqual([
'from x import y, y2'
'from x import y, y2, y_2'
]);
});
test('should render __future__ dependency first', () => {
Expand Down
124 changes: 124 additions & 0 deletions test/generators/python/presets/__snapshots__/Pydantic.spec.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,37 @@ exports[`PYTHON_PYDANTIC_PRESET should render pydantic for class 1`] = `
line
description''', default=None)
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None, alias='''additionalProperties''')
@model_serializer(mode='wrap')
def custom_serializer(self, handler):
serialized_self = handler(self)
additional_properties = getattr(self, \\"additional_properties\\")
if additional_properties is not None:
for key, value in additional_properties.items():
# Never overwrite existing values, to avoid clashes
if not hasattr(serialized_self, key):
serialized_self[key] = value
return serialized_self
@model_validator(mode='before')
@classmethod
def unwrap_additional_properties(cls, data):
json_properties = list(data.keys())
known_object_properties = ['prop', 'additional_properties']
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
# Ignore attempts that validate regular models, only when unknown input is used we add unwrap extensions
if len(unknown_object_properties) == 0:
return data
known_json_properties = ['prop', 'additionalProperties']
additional_properties = {}
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
additional_properties[obj_key] = data.pop(obj_key, None)
data['additionalProperties'] = additional_properties
return data
"
`;

Expand All @@ -15,14 +46,107 @@ Array [
"class UnionTest(BaseModel):
union_test: Optional[Union[Union1.Union1, Union2.Union2]] = Field(default=None, alias='''unionTest''')
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None, alias='''additionalProperties''')
@model_serializer(mode='wrap')
def custom_serializer(self, handler):
serialized_self = handler(self)
additional_properties = getattr(self, \\"additional_properties\\")
if additional_properties is not None:
for key, value in additional_properties.items():
# Never overwrite existing values, to avoid clashes
if not hasattr(serialized_self, key):
serialized_self[key] = value
return serialized_self
@model_validator(mode='before')
@classmethod
def unwrap_additional_properties(cls, data):
json_properties = list(data.keys())
known_object_properties = ['union_test', 'additional_properties']
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
# Ignore attempts that validate regular models, only when unknown input is used we add unwrap extensions
if len(unknown_object_properties) == 0:
return data
known_json_properties = ['unionTest', 'additionalProperties']
additional_properties = {}
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
additional_properties[obj_key] = data.pop(obj_key, None)
data['additionalProperties'] = additional_properties
return data
",
"class Union1(BaseModel):
test_prop1: Optional[str] = Field(default=None, alias='''testProp1''')
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None, alias='''additionalProperties''')
@model_serializer(mode='wrap')
def custom_serializer(self, handler):
serialized_self = handler(self)
additional_properties = getattr(self, \\"additional_properties\\")
if additional_properties is not None:
for key, value in additional_properties.items():
# Never overwrite existing values, to avoid clashes
if not hasattr(serialized_self, key):
serialized_self[key] = value
return serialized_self
@model_validator(mode='before')
@classmethod
def unwrap_additional_properties(cls, data):
json_properties = list(data.keys())
known_object_properties = ['test_prop1', 'additional_properties']
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
# Ignore attempts that validate regular models, only when unknown input is used we add unwrap extensions
if len(unknown_object_properties) == 0:
return data
known_json_properties = ['testProp1', 'additionalProperties']
additional_properties = {}
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
additional_properties[obj_key] = data.pop(obj_key, None)
data['additionalProperties'] = additional_properties
return data
",
"class Union2(BaseModel):
test_prop2: Optional[str] = Field(default=None, alias='''testProp2''')
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None, alias='''additionalProperties''')
@model_serializer(mode='wrap')
def custom_serializer(self, handler):
serialized_self = handler(self)
additional_properties = getattr(self, \\"additional_properties\\")
if additional_properties is not None:
for key, value in additional_properties.items():
# Never overwrite existing values, to avoid clashes
if not hasattr(serialized_self, key):
serialized_self[key] = value
return serialized_self
@model_validator(mode='before')
@classmethod
def unwrap_additional_properties(cls, data):
json_properties = list(data.keys())
known_object_properties = ['test_prop2', 'additional_properties']
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
# Ignore attempts that validate regular models, only when unknown input is used we add unwrap extensions
if len(unknown_object_properties) == 0:
return data
known_json_properties = ['testProp2', 'additionalProperties']
additional_properties = {}
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
additional_properties[obj_key] = data.pop(obj_key, None)
data['additionalProperties'] = additional_properties
return data
",
]
`;

0 comments on commit 62fd08b

Please sign in to comment.