Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions tests/test_graph_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Tests for graph save/load round-trip, including per-node graph_id."""

import json
from pathlib import Path

from worldgraph.graph import Graph, Node, Edge, load_graph, save_graph


def test_save_load_roundtrip_single_graph(tmp_path: Path):
"""Single-article graph: graph_id is always serialized per node."""
g = Graph(id="article-1")
n1 = g.add_entity("Alice")
n2 = g.add_entity("Bob")
g.add_edge(n1, n2, "knows")

path = tmp_path / "g.json"
save_graph(g, path)

# graph_id should appear on every node
with open(path) as f:
data = json.load(f)
for node_data in data["nodes"]:
assert node_data["graph_id"] == "article-1"

loaded = load_graph(path)
for node in loaded.nodes.values():
assert node.graph_id == "article-1"


def test_save_load_roundtrip_unified_graph(tmp_path: Path):
"""Unified graph with nodes from different source graphs preserves graph_id."""
g = Graph(id="unified")
# Manually add nodes with different source graph_ids
g.nodes["n1"] = Node(id="n1", graph_id="article-1", name="Alice")
g.nodes["n2"] = Node(id="n2", graph_id="article-2", name="Bob")
g.nodes["n3"] = Node(id="n3", graph_id="unified", name="Carol")
g.edges.append(Edge(source="n1", target="n2", relation="knows"))

path = tmp_path / "unified.json"
save_graph(g, path)

# Every node should have graph_id serialized
with open(path) as f:
data = json.load(f)
nodes_by_id = {n["id"]: n for n in data["nodes"]}
assert nodes_by_id["n1"]["graph_id"] == "article-1"
assert nodes_by_id["n2"]["graph_id"] == "article-2"
assert nodes_by_id["n3"]["graph_id"] == "unified"

# Round-trip preserves per-node graph_id
loaded = load_graph(path)
assert loaded.nodes["n1"].graph_id == "article-1"
assert loaded.nodes["n2"].graph_id == "article-2"
assert loaded.nodes["n3"].graph_id == "unified"
8 changes: 6 additions & 2 deletions worldgraph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def load_graph(path: Path) -> Graph:

for node_data in data["nodes"]:
node_id = node_data["id"]
nodes[node_id] = Node(id=node_id, graph_id=graph_id, name=node_data["name"])
nodes[node_id] = Node(
id=node_id,
graph_id=node_data["graph_id"],
name=node_data["name"],
)

edges: list[Edge] = []
for edge_data in data["edges"]:
Expand All @@ -70,7 +74,7 @@ def save_graph(
"""Write graph to JSON, with optional match groups."""
nodes_out = []
for node in graph.nodes.values():
nodes_out.append({"id": node.id, "name": node.name})
nodes_out.append({"id": node.id, "graph_id": node.graph_id, "name": node.name})

edges_out = []
for edge in graph.edges:
Expand Down