Skip to content

Commit 68018ca

Browse files
committed
feat: add readme and test files
1 parent 2482fe3 commit 68018ca

File tree

4 files changed

+506
-1
lines changed

4 files changed

+506
-1
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,7 @@ CLAUDE.md
187187

188188
test*.py
189189

190-
datasets/
190+
datasets/
191+
192+
!test_pipeline.py
193+
!test_convert_pipeline.py

rdagent/components/data/README.md

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# RD-Agent Data 组件
2+
3+
数据集处理 pipeline,从搜索到 SFT 格式转换的完整流程。
4+
5+
## 两阶段处理流程
6+
7+
### Phase 1: 数据采集(test_pipeline.py)
8+
9+
**流程**:搜索 → 下载 → 检查 → LLM过滤 → 选择性迁移
10+
11+
**涉及模块**
12+
- `search_api.py`:HuggingFace API 封装,支持 3 维搜索(domain/size/language)
13+
- `dataset_agent.py`:LLM 驱动的搜索代理,自动选择最佳数据集
14+
- `dataset_inspector.py`:数据集检查器,LLM 分析哪些文件有用
15+
- `dataset_manager.py`:存储管理,选择性迁移有用文件到 `./datasets/raw/`
16+
17+
**运行命令**
18+
```bash
19+
python test_pipeline.py
20+
```
21+
22+
**输出**
23+
- 数据集下载到临时目录 `/tmp/dataset_staging`
24+
- 有用文件迁移到 `./datasets/raw/`
25+
- 自动过滤垃圾文件,节省存储空间
26+
27+
---
28+
29+
### Phase 2: SFT 转换(test_convert_pipeline.py)
30+
31+
**流程**:加载数据 → Schema分析 → 智能路由 → 转换为 Alpaca 格式
32+
33+
**涉及模块**
34+
- `schema_analyzer.py`:LLM 分析数据 schema,识别 instruction/output 列
35+
- `data_converter.py`:转换为 Alpaca 格式,支持单轮/多轮对话
36+
- `data_cleaner.py`:数据清洗(去重、长度过滤、LLM质量打分)
37+
- `sft_processor.py`:主流程编排,智能路由(Light Path/Heavy Path)
38+
39+
**智能路由**
40+
- **Light Path**:数据质量 >0.8 → 简单转换 + 清洗
41+
- **Heavy Path**:数据复杂 → 直接 LLM 批量转换
42+
43+
**运行命令**
44+
```bash
45+
python test_convert_pipeline.py # 需要先运行 test_pipeline.py
46+
```
47+
48+
**输出**
49+
- Alpaca JSON 格式文件保存到 `./datasets/sft/`
50+
- 包含 instruction/input/output 字段
51+
- 经过去重和质量过滤(≥7.0分)
52+
53+
## 文件详细说明
54+
55+
### Phase 1: 数据采集相关文件
56+
57+
#### search_api.py(135行)
58+
- **核心类**`HuggingFaceSearchAPI`
59+
- **主要功能**:封装 HuggingFace Hub API,提供数据集搜索能力
60+
- **关键方法**
61+
- `search_datasets()`:支持 domain(模糊匹配)、size、language 三维搜索
62+
- `get_dataset_info()`:获取单个数据集的详细信息
63+
- **特点**:自动过滤需要申请权限的 gated datasets,返回结构化搜索结果
64+
65+
#### dataset_agent.py(499行)
66+
- **核心类**`DatasetSearchAgent`
67+
- **主要功能**:LLM 驱动的智能搜索代理,自动生成搜索参数并选择最佳数据集
68+
- **关键方法**
69+
- `search_and_download()`:完整流程(搜索→选择→下载)
70+
- `_generate_search_params()`:LLM 根据任务描述生成搜索参数
71+
- `_select_best_dataset()`:LLM 基于 4 维评估选择最佳数据集
72+
- `_apply_license_blacklist()`:过滤 NC/ND/GPL 等限制性 license
73+
- **特点**:混合重试策略(第1次 LLM 智能调整,后续规则式放松参数)
74+
75+
#### dataset_inspector.py(658行)
76+
- **核心类**`DatasetInspector`
77+
- **主要功能**:数据集质量检查和文件分析
78+
- **关键方法**
79+
- `inspect()`:加载数据集并提取结构信息(列名、样本数、数据类型等)
80+
- `check_quality()`:规则式质量检查(不依赖 LLM)
81+
- `analyze_files_for_sft()`:LLM 分析哪些文件对 SFT 训练有用
82+
- `_preview_xxx_file()`:支持 csv/json/parquet 等格式的文件预览
83+
- **特点**:智能文件分类,自动识别并过滤垃圾文件,节省存储空间
84+
85+
#### dataset_manager.py(109行)
86+
- **核心类**`DatasetManager`
87+
- **主要功能**:数据集存储和迁移管理
88+
- **关键方法**
89+
- `migrate_dataset_selective()`:基于文件分析结果,只迁移有用文件
90+
- **特点**:组织化存储结构(raw/ 和 converted/ 分离),自动创建目录
91+
92+
### Phase 2: SFT 转换相关文件
93+
94+
#### schema_analyzer.py
95+
- **核心类**`SchemaAnalyzer`
96+
- **主要功能**:LLM 分析数据集的 schema 结构
97+
- **关键方法**
98+
- `analyze()`:识别 instruction/input/output 列,判断单轮/多轮对话
99+
- `_validate_schema_result()`:验证 LLM 输出格式是否正确
100+
- **返回格式**:包含 data_type、instruction_col、output_col、input_col、reasoning
101+
- **特点**:3 次重试机制,失败时有启发式 fallback
102+
103+
#### data_converter.py
104+
- **核心类**`DataConverter`
105+
- **主要功能**:将各种格式数据转换为标准 Alpaca 格式
106+
- **关键方法**
107+
- `convert_to_alpaca()`:主转换入口
108+
- `_convert_single_turn()`:单轮 QA 转换逻辑
109+
- `_convert_multi_turn()`:多轮对话转换,保留历史作为 context
110+
- `_extract_metadata()`:智能提取元数据(白名单优先,黑名单排除)
111+
- **支持格式**:csv、json、jsonl、parquet、arrow
112+
113+
#### data_cleaner.py
114+
- **核心类**`DataCleaner`
115+
- **主要功能**:数据清洗和质量过滤
116+
- **清洗流程**
117+
1. 去重:基于 instruction+output 的 MD5 哈希
118+
2. 长度过滤:设置最小/最大长度阈值
119+
3. 质量打分:LLM 批量评分(10条/批),保留 ≥7.0 分
120+
- **特点**:20 workers 并行处理,采样策略(超过 10000 条只评分前 10000)
121+
122+
#### sft_processor.py
123+
- **核心类**`SFTProcessor``CheckpointManager`
124+
- **主要功能**:生产级 SFT 数据准备系统,完整 pipeline 编排
125+
- **智能路由**
126+
- Light Path(质量>0.8):schema分析 → 简单转换 → 清洗
127+
- Heavy Path(质量≤0.8):直接 LLM 批量转换
128+
- **关键特性**
129+
- 断点续传:batch 级别 checkpoint,中断可恢复
130+
- 并行处理:20 workers 同时处理
131+
- 增量保存:每完成 1 个 batch 立即保存
132+
- **特点**:整合所有上述模块,提供统一入口
133+
134+
### 辅助文件
135+
136+
#### prompts.yaml
137+
- **功能**:集中管理所有 LLM 提示词模板
138+
- **包含提示词**
139+
- search_params:生成搜索参数
140+
- dataset_selection:数据集选择评估
141+
- schema_analysis_for_sft:schema 结构分析
142+
- quality_scoring_batch:批量质量打分
143+
- heavy_conversion:Heavy Path 直接转换
144+
- **特点**:使用模板系统渲染,便于维护和更新
145+
146+
#### __init__.py
147+
- **功能**:模块导出和便捷函数
148+
- **导出内容**:所有主要类 + `convert_to_sft()` 一行代码函数
149+
- **便捷函数**:自动完成从搜索到输出的完整流程
150+
151+
## 快速使用
152+
153+
### 方式一:两步运行
154+
```bash
155+
# Phase 1: 数据采集
156+
python test_pipeline.py
157+
158+
# Phase 2: SFT 转换
159+
python test_convert_pipeline.py
160+
```
161+
162+
### 方式二:一行代码
163+
```python
164+
from rdagent.components.data import convert_to_sft
165+
166+
convert_to_sft(
167+
input_path="data/raw/",
168+
output_file="output/alpaca.json",
169+
task_description="数学推理数据集"
170+
)
171+
```
172+
173+
174+
175+
176+
## 依赖关系
177+
178+
```
179+
Phase 1: Phase 2:
180+
dataset_agent → search_api sft_processor
181+
↓ ├── schema_analyzer
182+
dataset_inspector ├── data_converter
183+
↓ └── data_cleaner
184+
dataset_manager ↑
185+
prompts.yaml
186+
```
187+
188+
## Alpaca 输出格式
189+
190+
```json
191+
{
192+
"instruction": "问题或指令",
193+
"input": "输入上下文(可选)",
194+
"output": "回答或输出",
195+
"metadata": {
196+
"category": "分类",
197+
"difficulty": "难度"
198+
}
199+
}
200+
```
201+

test_convert_pipeline.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""
2+
Dataset SFT Conversion Pipeline (Phase 2: SFT Conversion).
3+
Workflow: Load migrated dataset → Schema Analysis → Intelligent Routing → Convert to Alpaca format
4+
5+
Prerequisites: Run test_pipeline.py first to migrate dataset to ./datasets/raw/
6+
"""
7+
8+
import json
9+
import os
10+
from pathlib import Path
11+
12+
from rdagent.components.data import convert_to_sft
13+
14+
# Configuration
15+
DATASETS_ROOT = Path("./datasets/raw")
16+
OUTPUT_DIR = Path("./datasets/sft")
17+
TASK_DESCRIPTION = "数学推理数据集" # 需要与 test_pipeline.py 保持一致
18+
19+
print("=" * 70)
20+
print("SFT 转换流程 (Phase 2: 数据转换与清洗)")
21+
print("=" * 70)
22+
print(f"数据集根目录: {DATASETS_ROOT}")
23+
print(f"输出目录: {OUTPUT_DIR}")
24+
print(f"任务描述: {TASK_DESCRIPTION}")
25+
print("=" * 70)
26+
27+
28+
def find_latest_dataset(datasets_root: Path) -> Path:
29+
"""查找最新迁移的数据集目录"""
30+
if not datasets_root.exists():
31+
raise FileNotFoundError(f"数据集根目录不存在: {datasets_root}")
32+
33+
# 获取所有子目录
34+
subdirs = [d for d in datasets_root.iterdir() if d.is_dir()]
35+
36+
if not subdirs:
37+
raise FileNotFoundError(f"未找到任何数据集: {datasets_root}")
38+
39+
# 按修改时间排序,返回最新的
40+
latest_dataset = max(subdirs, key=lambda d: d.stat().st_mtime)
41+
return latest_dataset
42+
43+
44+
def test_sft_conversion():
45+
"""测试 SFT 转换流程(智能分流)"""
46+
47+
# Step 1: 查找最新数据集
48+
print("\n[Step 1/3] 查找迁移后的数据集...")
49+
50+
try:
51+
dataset_path = find_latest_dataset(DATASETS_ROOT)
52+
print(f"✅ 找到数据集: {dataset_path.name}")
53+
print(f" 路径: {dataset_path}")
54+
print(f" 修改时间: {dataset_path.stat().st_mtime}")
55+
except Exception as e:
56+
print(f"❌ 未找到数据集: {e}")
57+
print(f"\n提示: 请先运行 test_pipeline.py 下载并迁移数据集")
58+
return False
59+
60+
# Step 2: 准备输出路径
61+
print("\n[Step 2/3] 准备输出路径...")
62+
63+
output_file = OUTPUT_DIR / f"{dataset_path.name}_alpaca.json"
64+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
65+
66+
print(f"✅ 输出路径准备完成")
67+
print(f" 输出文件: {output_file}")
68+
69+
# Clean checkpoint before conversion
70+
checkpoint_file = Path("sft_checkpoint.json")
71+
if checkpoint_file.exists():
72+
checkpoint_file.unlink()
73+
print(f" 🧹 清理 checkpoint: {checkpoint_file}")
74+
75+
# Step 3: SFT 转换(智能分流)
76+
print("\n[Step 3/3] SFT 转换(智能分流系统)...")
77+
print("=" * 70)
78+
print("智能分流说明:")
79+
print(" - 轻量路径 (Light Path): 标准 Q&A 数据 → 简单转换 + 去重 + 并行质量评分")
80+
print(" - 重度路径 (Heavy Path): 混乱数据 → 去重 + 直接并行 LLM 转换")
81+
print(" - 系统自动根据数据质量选择路径")
82+
print("=" * 70)
83+
84+
try:
85+
result = convert_to_sft(
86+
input_path=str(dataset_path),
87+
output_file=str(output_file),
88+
task_description=TASK_DESCRIPTION,
89+
)
90+
91+
# 验证结果
92+
print("\n✅ 转换完成!")
93+
print("=" * 70)
94+
print("转换统计:")
95+
print(f" 处理路径: {result.get('processing_path', 'unknown').upper()}")
96+
print(f" 成功状态: {result['success']}")
97+
print(f" 输入样本: {result['stats'].get('total_rows', 0)}")
98+
print(f" 输出样本: {result['stats'].get('successful_rows', 0)}")
99+
print(f" 质量分数: {result['stats'].get('quality_score', 0):.2f}")
100+
print("=" * 70)
101+
102+
# 检查输出文件
103+
if output_file.exists():
104+
with open(output_file, "r", encoding="utf-8") as f:
105+
output_data = json.load(f)
106+
107+
print(f"\n📄 输出文件验证:")
108+
print(f" 文件路径: {output_file}")
109+
print(f" 样本总数: {len(output_data)}")
110+
print(f" 文件大小: {output_file.stat().st_size / 1024 / 1024:.2f}MB")
111+
print(f" 格式验证: {'✓' if all('instruction' in s and 'output' in s for s in output_data) else '✗'}")
112+
113+
# 显示示例
114+
if output_data:
115+
print(f"\n📝 示例样本 (前 3 个):")
116+
for i, sample in enumerate(output_data[:3]):
117+
print(f"\n 样本 {i+1}:")
118+
print(f" instruction: {sample['instruction'][:80]}...")
119+
if sample.get("input"):
120+
print(f" input: {sample['input'][:60]}...")
121+
print(f" output: {sample['output'][:80]}...")
122+
if "metadata" in sample:
123+
print(f" metadata: {sample['metadata']}")
124+
125+
# 数据质量统计
126+
if output_data:
127+
avg_instruction_len = sum(len(s["instruction"]) for s in output_data) / len(output_data)
128+
avg_output_len = sum(len(s["output"]) for s in output_data) / len(output_data)
129+
has_metadata = sum(1 for s in output_data if "metadata" in s)
130+
131+
print(f"\n📊 数据质量统计:")
132+
print(f" 平均 instruction 长度: {avg_instruction_len:.0f} 字符")
133+
print(f" 平均 output 长度: {avg_output_len:.0f} 字符")
134+
print(f" 包含 metadata: {has_metadata}/{len(output_data)} 样本")
135+
136+
return result["success"]
137+
138+
except Exception as e:
139+
print(f"❌ 转换失败: {e}")
140+
import traceback
141+
142+
traceback.print_exc()
143+
return False
144+
145+
146+
def main():
147+
"""运行完整的 SFT 转换流程"""
148+
149+
success = test_sft_conversion()
150+
151+
# 总结
152+
print("\n" + "=" * 70)
153+
if success:
154+
print("✅ SFT 转换流程完成!")
155+
print("=" * 70)
156+
print("下一步建议:")
157+
print(" 1. 检查输出文件质量")
158+
print(" 2. 使用输出文件进行 LoRA/SFT 训练")
159+
print(f" 3. 输出文件位置: {OUTPUT_DIR}")
160+
else:
161+
print("❌ SFT 转换失败!")
162+
print("=" * 70)
163+
print("请检查:")
164+
print(" 1. 是否已运行 test_pipeline.py 迁移数据集?")
165+
print(" 2. 数据集格式是否正确?")
166+
print(" 3. LLM API 是否配置正确?")
167+
print("=" * 70)
168+
169+
170+
if __name__ == "__main__":
171+
main()

0 commit comments

Comments
 (0)