Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions agentic_doc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,17 @@ def create_metadata_model(model: type[BaseModel]) -> type[BaseModel]:
"""
Recursively creates a new Pydantic model from an existing one,
replacing all leaf-level field types with MetadataType.
Preserves field aliases to maintain compatibility with API responses.
"""
fields: Dict[str, Any] = {}
for name, field in model.model_fields.items():
field_type = field.annotation

origin = get_origin(field_type)

# Preserve the original field's alias if it exists
alias = field.alias if field.alias else None

# Handle Optional/Union types
if origin is Union:
args = get_args(field_type)
Expand All @@ -101,11 +105,14 @@ def create_metadata_model(model: type[BaseModel]) -> type[BaseModel]:
non_none_type, BaseModel
):
metadata_type = create_metadata_model(non_none_type)
fields[name] = (Optional[metadata_type], Field(default=None))
fields[name] = (
Optional[metadata_type],
Field(default=None, alias=alias),
)
else:
fields[name] = (
Optional[MetadataType[non_none_type]], # type: ignore[valid-type]
Field(default=None),
Field(default=None, alias=alias),
)
continue

Expand All @@ -116,23 +123,23 @@ def create_metadata_model(model: type[BaseModel]) -> type[BaseModel]:
metadata_inner_type = create_metadata_model(inner_type)
fields[name] = (
List[metadata_inner_type], # type: ignore[valid-type]
Field(default_factory=list), # type: ignore[arg-type]
Field(default_factory=list, alias=alias), # type: ignore[arg-type]
)
else:
fields[name] = (
List[MetadataType[inner_type]], # type: ignore[valid-type]
Field(default_factory=list), # type: ignore[arg-type]
Field(default_factory=list, alias=alias), # type: ignore[arg-type]
)
continue

# Handle nested models
if inspect.isclass(field_type) and issubclass(field_type, BaseModel):
fields[name] = (create_metadata_model(field_type), Field())
fields[name] = (create_metadata_model(field_type), Field(alias=alias))
else:
# Replace primitive leaf with MetadataType[original type]
fields[name] = (
MetadataType[field_type], # type: ignore[valid-type]
Field(),
Field(alias=alias),
)

return create_model(f"{model.__name__}Metadata", **fields)
Expand Down