-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
183 lines (152 loc) · 7.03 KB
/
main.py
File metadata and controls
183 lines (152 loc) · 7.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import asyncio
import json
import os
import importlib
import inspect
from agents.clientpool import MultiKeyClientPool
from agents.call_agent import GenAgent
from core.processor import KnowledgeProcessor
from core.judge import DomainJudge
from core.evaluator import Evaluator
from core.cleaner import KnowledgeCleaner
from pipelines.base import BasePipeline
import argparse
async def main():
    """Run the LLM knowledge-extraction evaluation framework end to end.

    Workflow:
      0. Parse CLI arguments (``--query``, ``--is_code``).
      1. Create a per-query output directory under ``results/``.
      2. Load API keys and embedding configuration from ``api.json``.
      3. Discover and instantiate pipelines from ``pipelines/``.
      4. Phase 1: run every pipeline to saturation, persisting raw points
         and token costs per pipeline.
      5. Phase 2: build a global deduplicated union set, audit each unique
         node for domain membership, compute per-pipeline recall/accuracy,
         and print a leaderboard.

    Side effects: reads ``api.json`` from the CWD; writes JSON artifacts
    under ``results/<query_id>/``; prints progress and results to stdout.
    """
    # 0. Parse arguments
    parser = argparse.ArgumentParser(description="LLM Knowledge Extraction Evaluation Framework")
    parser.add_argument("--query", type=str, default="Linear Algebra", help="Domain to extract knowledge from")
    parser.add_argument("--is_code", action="store_true", help="Whether the domain is code-related")
    args = parser.parse_args()

    # 1. Setup: derive a filesystem-safe directory name from the query.
    domain_query = args.query
    is_code_domain = args.is_code
    query_id = domain_query.lower().replace(" ", "_")
    output_dir = os.path.join("results", query_id)
    os.makedirs(output_dir, exist_ok=True)

    # 2. Load configuration. Missing/incomplete config exits gracefully
    #    rather than raising, to keep CLI error output readable.
    if not os.path.exists("api.json"):
        print("Error: api.json not found.")
        return
    with open("api.json") as f:
        api_data = json.load(f)
    api_keys = api_data.get("api_keys")
    if not api_keys:
        print("Error: 'api_keys' missing or empty in api.json.")
        return

    # Initialize shared components
    client_pool = MultiKeyClientPool(api_keys=api_keys)

    # 3. Determine embedding strategy: code domains may declare a dedicated
    #    embedding endpoint ("code_embed_config"); otherwise fall back to
    #    the generic "embed_config", then to hard-coded defaults.
    config_key = "code_embed_config" if is_code_domain else "embed_config"
    embed_config = api_data.get(config_key, api_data.get("embed_config", {}))
    if embed_config:
        print(f"Using {'CODE' if is_code_domain else 'TEXT'} embedding model: {embed_config.get('model')} at {embed_config.get('base_url')}")
        # The embedding service may use its own key set; default to the
        # shared keys when none are declared.
        target_keys = embed_config.get("api_keys", api_keys)
        embed_client_pool = MultiKeyClientPool(
            api_keys=target_keys,
            base_url=embed_config.get("base_url")
        )
        embed_model = embed_config.get("model")
        threshold = embed_config.get("threshold", 0.92)
        candidate_threshold = embed_config.get("candidate_threshold", 0.82)
    else:
        # Fallback to defaults
        embed_client_pool = client_pool
        embed_model = "nvidia/nv-embedcode-7b-v1" if is_code_domain else "nvidia/nv-embed-v1"
        threshold = 0.92
        candidate_threshold = 0.82

    gen_agent = GenAgent(api_key=api_keys)
    processor = KnowledgeProcessor(
        client_pool=client_pool,
        embed_client_pool=embed_client_pool,
        embed_model=embed_model,
        threshold=threshold,
        candidate_threshold=candidate_threshold
    )
    judge = DomainJudge(client_pool=client_pool)
    print(f"Starting extraction for domain: {domain_query} (Type: {'Code' if is_code_domain else 'Text'})")

    # 4. Pipeline discovery: import each listed module and instantiate every
    #    concrete BasePipeline subclass it defines.
    active_pipelines = []
    target_files = [
        "p2_sequential.py",
        "p3_reflection.py",
        "p4_taxonomy_explorer.py",
        "p5_debate"
    ]
    for filename in target_files:
        # Entries may be listed with or without the .py extension.
        if not filename.endswith(".py"):
            filename += ".py"
        module_name = f"pipelines.{filename[:-3]}"
        try:
            module = importlib.import_module(module_name)
            for name, obj in inspect.getmembers(module):
                if inspect.isclass(obj) and issubclass(obj, BasePipeline) and obj != BasePipeline:
                    active_pipelines.append((name, obj(gen_agent, processor)))
        except Exception as e:
            # FIX: was "Error loading (unknown)" — name the failing module
            # so import errors are attributable.
            print(f"Error loading {module_name}: {e}")

    # --- PHASE 1: SATURATION GENERATION ---
    print("\n=== PHASE 1: GENERATION (Saturation Mode) ===")
    pipeline_raw_outputs = {}  # {pipeline_name: [raw_points]}
    for name, pipeline in active_pipelines:
        print(f"Running pipeline: {name}...")
        try:
            run_result = await pipeline.run(domain_query)
            raw_points = run_result["points"]
            total_tokens = run_result["total_tokens"]
            pipeline_raw_outputs[name] = raw_points
            # Save raw points immediately so partial results survive a
            # failure in a later pipeline.
            with open(os.path.join(output_dir, f"{name}_raw.json"), "w") as f:
                json.dump(raw_points, f, indent=2)
            # Save tokens
            with open(os.path.join(output_dir, f"{name}.tokens.json"), "w") as f:
                json.dump({"total_tokens": total_tokens}, f, indent=2)
            print(f" {name} finished. Total raw points: {len(raw_points)} (Cost: {total_tokens} tokens)")
        except Exception as e:
            # One failing pipeline must not abort the whole benchmark.
            print(f" Error running {name}: {e}")

    if not pipeline_raw_outputs:
        print("No outputs generated. Exiting.")
        return

    # --- PHASE 2: GLOBAL EVALUATION ---
    print("\n=== PHASE 2: EVALUATION (Post-Audit) ===")

    # 1. Build Global Union Set (deduplicated across all pipelines).
    print("Building global deduplicated union set...")
    await processor.build_union_set(output_dir)

    # 2. Global Domain Audit: ask the judge which unique nodes actually
    #    belong to the queried domain.
    print(f"Auditing {len(processor.union_set)} unique knowledge nodes...")
    unique_texts = [node["representative_text"] for node in processor.union_set]
    audit_results = await judge.check_batch(domain_query, unique_texts)

    # Apply audit results
    valid_nodes_count = 0
    for node, is_valid in zip(processor.union_set, audit_results):
        node["is_in_domain"] = is_valid
        if is_valid:
            valid_nodes_count += 1
    print(f"Audit complete: {valid_nodes_count} valid / {len(processor.union_set)} total nodes.")

    # Save the audited union set
    processor.save_union_set(os.path.join(output_dir, "union_set.json"))

    # 3. Metric Calculation.
    #    recall   = valid unique nodes this pipeline covered / all valid nodes
    #    accuracy = raw points that landed on a valid node / all raw points
    total_valid_nodes = valid_nodes_count
    pipeline_metrics = {}
    for name, raw_points in pipeline_raw_outputs.items():
        covered_indices = processor.get_pipeline_coverage(name)
        valid_covered_indices = [i for i in covered_indices if processor.union_set[i]["is_in_domain"]]
        recall = len(valid_covered_indices) / total_valid_nodes if total_valid_nodes > 0 else 0

        # Count this pipeline's raw contributions to in-domain nodes: each
        # source entry from this pipeline on a valid node counts once.
        valid_raw_count = 0
        for node_idx in covered_indices:
            node = processor.union_set[node_idx]
            if node["is_in_domain"]:
                for source in node["source_entries"]:
                    if source["pipeline"] == name:
                        valid_raw_count += 1
        accuracy = valid_raw_count / len(raw_points) if raw_points else 0
        pipeline_metrics[name] = {"recall": recall, "accuracy": accuracy, "raw_total": len(raw_points)}

    # --- FINAL LEADERBOARD --- (sorted by recall, best first)
    print("\n" + "=" * 75)
    print(f"LEADERBOARD (Saturation Extraction) for Domain: {domain_query}")
    print("=" * 75)
    print(f"{'Pipeline':30} | {'Recall':10} | {'Accuracy':10} | {'Raw Points'}")
    print("-" * 75)
    sorted_keys = sorted(pipeline_metrics.keys(), key=lambda x: pipeline_metrics[x]['recall'], reverse=True)
    for name in sorted_keys:
        m = pipeline_metrics[name]
        print(f"{name:30} | {m['recall']:7.2%} | {m['accuracy']:7.2%} | {m['raw_total']}")
    print("=" * 75)
    print(f"Global Pseudo Ground Truth Size: {total_valid_nodes} unique points.")
# Script entry point: run the async workflow to completion on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())