Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,26 @@ def extract_property_graph_by_llm(self, schema, chunk):
return self.llm.generate(prompt=prompt)

def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]:
# Use regex to extract a JSON object with curly braces
json_match = re.search(r"({.*})", text, re.DOTALL)
# Strip markdown code blocks (e.g. ```json ... ```)
text = re.sub(r"```\w*\n?", "", text)
text = re.sub(r"```", "", text)
text = text.strip()

# Try to extract JSON (object or array)
json_match = re.search(r"(\{.*\}|\[.*\])", text, re.DOTALL)
if not json_match:
log.critical(
"Invalid property graph! No JSON object found, please check the output format example in prompt."
)
log.critical("Invalid property graph! No JSON found, please check the output format example in prompt.")
return []
json_str = json_match.group(1).strip()

items = []
try:
property_graph = json.loads(json_str)
# Handle flat array format: convert to {"vertices": [...], "edges": [...]}
if isinstance(property_graph, list):
vertices = [item for item in property_graph if isinstance(item, dict) and item.get("type") == "vertex"]
edges = [item for item in property_graph if isinstance(item, dict) and item.get("type") == "edge"]
property_graph = {"vertices": vertices, "edges": edges}
# Expect property_graph to be a dict with keys "vertices" and "edges"
if not (isinstance(property_graph, dict) and "vertices" in property_graph and "edges" in property_graph):
log.critical("Invalid property graph format; expecting 'vertices' and 'edges'.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,160 @@ def test_extract_and_filter_label_valid_json(self):
self.assertEqual(result[1]["type"], "edge")
self.assertEqual(result[1]["label"], "acted_in")

def test_extract_and_filter_label_markdown_json(self):
    """A reply wrapped in a json-tagged markdown fence yields the expected items.

    The fenced payload is the fixture ``self.llm_responses[1]``; the parsed
    result must hold exactly one ``movie`` vertex followed by one
    ``acted_in`` edge.
    """
    graph_extractor = PropertyGraphExtract(llm=self.mock_llm)
    fenced_reply = f"""```json
{self.llm_responses[1]}
```"""

    items = graph_extractor._extract_and_filter_label(self.schema, fenced_reply)

    # (type, label) pairs, in the order the items must be returned.
    expected = [("vertex", "movie"), ("edge", "acted_in")]
    self.assertEqual(len(items), 2)
    for item, (item_type, item_label) in zip(items, expected):
        self.assertEqual(item["type"], item_type)
        self.assertEqual(item["label"], item_label)

def test_extract_and_filter_label_markdown_json_with_prose(self):
    """Fenced JSON surrounded by explanatory prose is still parsed.

    The reply has a sentence before and after an untagged markdown fence
    containing ``self.llm_responses[1]``; the result must be one ``movie``
    vertex followed by one ``acted_in`` edge.
    """
    graph_extractor = PropertyGraphExtract(llm=self.mock_llm)
    chatty_reply = f"""Here is the extracted graph:
```
{self.llm_responses[1]}
```
Hope this helps."""

    items = graph_extractor._extract_and_filter_label(self.schema, chatty_reply)

    # (type, label) pairs, in the order the items must be returned.
    expected = [("vertex", "movie"), ("edge", "acted_in")]
    self.assertEqual(len(items), 2)
    for item, (item_type, item_label) in zip(items, expected):
        self.assertEqual(item["type"], item_type)
        self.assertEqual(item["label"], item_label)

def test_extract_and_filter_label_flat_array_json(self):
    """A fenced top-level JSON array of typed items is accepted.

    The payload is a flat list holding one ``"type": "vertex"`` entry and
    one ``"type": "edge"`` entry rather than a ``vertices``/``edges``
    object; both must come back, vertex first.
    """
    graph_extractor = PropertyGraphExtract(llm=self.mock_llm)
    flat_array_reply = """```json
[
{
"type": "vertex",
"label": "person",
"properties": {
"name": "Tom Hanks"
}
},
{
"type": "edge",
"label": "acted_in",
"properties": {
"role": "Forrest Gump"
},
"source": {
"label": "person",
"properties": {
"name": "Tom Hanks"
}
},
"target": {
"label": "movie",
"properties": {
"title": "Forrest Gump"
}
}
}
]
```"""

    items = graph_extractor._extract_and_filter_label(self.schema, flat_array_reply)

    # (type, label) pairs, in the order the items must be returned.
    expected = [("vertex", "person"), ("edge", "acted_in")]
    self.assertEqual(len(items), 2)
    for item, (item_type, item_label) in zip(items, expected):
        self.assertEqual(item["type"], item_type)
        self.assertEqual(item["label"], item_label)

def test_extract_and_filter_label_flat_array_filters_invalid_items(self):
    """Only the valid entries of an unfenced flat array survive.

    The array mixes a ``person`` vertex and an ``acted_in`` edge with
    entries expected to be dropped: a vertex labelled ``unknown_label``,
    an edge labelled ``unknown_edge``, an entry whose type is ``note``,
    and a bare string. Exactly the first two kinds must be returned.
    """
    graph_extractor = PropertyGraphExtract(llm=self.mock_llm)
    noisy_array_reply = """[
{
"type": "vertex",
"label": "person",
"properties": {
"name": "Tom Hanks"
}
},
{
"type": "vertex",
"label": "unknown_label",
"properties": {
"name": "Unknown"
}
},
{
"type": "edge",
"label": "acted_in",
"properties": {
"role": "Forrest Gump"
},
"source": {
"label": "person",
"properties": {
"name": "Tom Hanks"
}
},
"target": {
"label": "movie",
"properties": {
"title": "Forrest Gump"
}
}
},
{
"type": "edge",
"label": "unknown_edge",
"properties": {}
},
{
"type": "note",
"label": "person",
"properties": {}
},
"not-a-dict"
]"""

    items = graph_extractor._extract_and_filter_label(self.schema, noisy_array_reply)

    # (type, label) pairs, in the order the surviving items must be returned.
    expected = [("vertex", "person"), ("edge", "acted_in")]
    self.assertEqual(len(items), 2)
    for item, (item_type, item_label) in zip(items, expected):
        self.assertEqual(item["type"], item_type)
        self.assertEqual(item["label"], item_label)

def test_extract_and_filter_label_malformed_fenced_json(self):
    """Fenced JSON missing its closing brace produces an empty result.

    The payload opens a ``vertices``/``edges`` object but the fence ends
    before the outer ``}``, so no graph items are expected.
    """
    graph_extractor = PropertyGraphExtract(llm=self.mock_llm)
    truncated_reply = """```json
{
"vertices": [
{
"type": "vertex",
"label": "person",
"properties": {
"name": "Tom Hanks"
}
}
],
"edges": []
```
"""

    items = graph_extractor._extract_and_filter_label(self.schema, truncated_reply)

    self.assertEqual(items, [])

def test_extract_and_filter_label_invalid_json(self):
"""Test the _extract_and_filter_label method with invalid JSON."""
extractor = PropertyGraphExtract(llm=self.mock_llm)
Expand Down
Loading