Skip to content

Commit a9cd04c

Browse files
committed
fix: a bug where train.jsonl does not exist
Signed-off-by: Hung-Han (Henry) Chen <[email protected]>
1 parent e372610 commit a9cd04c

File tree

3 files changed

+20
-17
lines changed

3 files changed

+20
-17
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "mlx-training-rs"
3-
version = "0.2.3"
3+
version = "0.2.4"
44
edition = "2021"
55
repository = "https://github.com/chenhunghan/mlx-training-rs"
66

src/main.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,17 @@ use mlx_training_rs::cli::CLI;
77
use serde::Deserialize;
88
use tokio::fs::{self, OpenOptions};
99
use tokio::io::AsyncWriteExt;
10-
use tokio::runtime::Runtime;
1110
use serde_json;
1211
use async_openai::{Client, types::{CreateChatCompletionRequestArgs, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs}};
1312

14-
fn main() {
15-
let rt = Runtime::new().unwrap();
16-
rt.block_on(main_async()).unwrap();
17-
}
18-
19-
async fn main_async() -> Result<(), Box<dyn Error>> {
13+
#[tokio::main]
14+
async fn main() -> Result<(), Box<dyn Error>> {
2015
// Parse command line arguments
2116
let cli = CLI::parse();
2217
let topic = &cli.topic;
2318
let n = cli.n;
2419

20+
tokio::fs::create_dir_all("./data").await?;
2521
write_instruction_jsonl(topic, n).await?;
2622
write_train_jsonl().await?;
2723
create_valid_file().await?;
@@ -55,6 +51,12 @@ async fn write_instruction_jsonl(topic: &str, n: usize) -> Result<(), Box<dyn Er
5551
let instructions = fs::read_to_string(&file_path).await?;
5652
let instructions: Vec<Instruction> = instructions.lines().map(|line| serde_json::from_str(&line).unwrap()).collect();
5753

54+
// Open the file in append mode
55+
let mut file = OpenOptions::new()
56+
.append(true)
57+
.open(&file_path)
58+
.await?;
59+
5860
println!("------------------------------");
5961
println!("{}", format!("Generating instructions on topic {}...", topic));
6062
for _ in 0..n {
@@ -64,12 +66,6 @@ async fn write_instruction_jsonl(topic: &str, n: usize) -> Result<(), Box<dyn Er
6466
// println!("Skipping duplicate instruction: {}", instruction);
6567
continue;
6668
} else {
67-
// Open the file in append mode
68-
let mut file = OpenOptions::new()
69-
.create(true)
70-
.append(true)
71-
.open(&file_path)
72-
.await?;
7369

7470
println!("------------------------------");
7571
println!("Writing new instruction to file: {}", instruction);
@@ -107,9 +103,16 @@ async fn write_train_jsonl() -> Result<(), Box<dyn Error>> {
107103
let total = instructions.len();
108104

109105
let train_file_path = PathBuf::from("./data/").join("train.jsonl");
106+
if !train_file_path.exists() {
107+
println!("Creating train.jsonl file...");
108+
let _ = OpenOptions::new()
109+
.create(true)
110+
.append(true)
111+
.open(&train_file_path)
112+
.await?;
113+
}
110114
let trainings: Vec<Train> = fs::read_to_string(&train_file_path).await?.lines().filter_map(|line| serde_json::from_str(&line).ok()).collect();
111-
print!("{} data found in train.jsonl. ", trainings.len());
112-
115+
113116
for (i, instruction) in instructions.iter().enumerate() {
114117
if let Some(_) = trainings.iter().find(|t| t.text.contains(&instruction.text)) {
115118
// println!("Skipping processing instruction {} because it can be found in train.jsonl", instruction.text);

0 commit comments

Comments
 (0)