Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 0 additions & 45 deletions .github/workflows/homebrew-bump.yml

This file was deleted.

26 changes: 26 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Test

on:
push:
branches:
- main
- master
pull_request:

jobs:
rust-tests:
runs-on: ubuntu-latest

steps:
- name: Checkout
uses: actions/checkout@v4

- name: Setup Rust
uses: dtolnay/rust-toolchain@stable

- name: Cache cargo
uses: Swatinem/rust-cache@v2


- name: Run unit tests
run: cargo test
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "corgea"
version = "1.8.0"
version = "1.8.1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
224 changes: 204 additions & 20 deletions src/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,32 +137,36 @@ async fn start_callback_server(
};

loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let auth_code_clone = auth_code.clone();

let service = service_fn(move |req| {
handle_callback(req, auth_code_clone.clone())
});

tokio::task::spawn(async move {
if let Err(err) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.await
{
eprintln!("Error serving connection: {:?}", err);
tokio::select! {
accept_result = listener.accept() => {
let (stream, _) = accept_result?;
let io = TokioIo::new(stream);
let auth_code_clone = auth_code.clone();

let service = service_fn(move |req| {
handle_callback(req, auth_code_clone.clone())
});

tokio::task::spawn(async move {
if let Err(err) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.await
{
eprintln!("Error serving connection: {:?}", err);
}
});
}
});

// Check if we got the code
_ = tokio::time::sleep(Duration::from_millis(100)) => {}
}

// Check if we got the code.
// We must do this outside of `accept()` blocking so we don't miss a code
// that was set by the request task after a single callback request.
if let Ok(code_guard) = auth_code.lock() {
if let Some(code) = code_guard.as_ref() {
return Ok(code.clone());
}
}

// Add a small delay to prevent busy waiting
tokio::time::sleep(Duration::from_millis(100)).await;
}
}

Expand Down Expand Up @@ -523,3 +527,183 @@ fn parse_query_params(query: &str) -> HashMap<String, String> {
}



#[cfg(test)]
mod tests {
use super::*;
use std::io::{Read, Write};
use std::net::{TcpListener as StdTcpListener, TcpStream};
use std::sync::mpsc;
use std::thread;
use std::time::Duration as StdDuration;
use tokio::runtime::Runtime;
use tokio::time::{timeout, Duration};

fn reserve_ephemeral_port() -> u16 {
let listener = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port");
listener.local_addr().expect("failed to get local addr").port()
}

fn spawn_callback_server(
port: u16,
auth_code: Arc<Mutex<Option<String>>>,
) -> mpsc::Receiver<Result<String, String>> {
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let runtime = Runtime::new().expect("failed to create tokio runtime");
let result = runtime
.block_on(start_callback_server(port, auth_code))
.map_err(|e| e.to_string());
tx.send(result).expect("failed to send callback result");
});

rx
}

fn send_http_get(port: u16, path: &str) -> (u16, String) {
let mut stream = None;

for _ in 0..50 {
match TcpStream::connect(("127.0.0.1", port)) {
Ok(s) => {
stream = Some(s);
break;
}
Err(_) => thread::sleep(StdDuration::from_millis(20)),
}
}

let mut stream = stream.expect("failed to connect to callback server");
let request = format!("GET {} HTTP/1.0\r\n\r\n", path);

stream
.write_all(request.as_bytes())
.expect("failed to write request");

let mut raw_response = String::new();
stream
.read_to_string(&mut raw_response)
.expect("failed to read response");

let mut sections = raw_response.splitn(2, "\r\n\r\n");
let headers = sections.next().expect("response headers missing");
let body = sections.next().unwrap_or_default().to_string();
let status_line = headers.lines().next().expect("status line missing");
let status = status_line
.split_whitespace()
.nth(1)
.expect("status code missing")
.parse::<u16>()
.expect("invalid status code");

(status, body)
}

#[test]
fn parse_query_params_decodes_values() {
let params = parse_query_params("code=a%20b&error_description=needs%2Blogin");

assert_eq!(params.get("code"), Some(&"a b".to_string()));
assert_eq!(params.get("error_description"), Some(&"needs+login".to_string()));
}

#[test]
fn parse_query_params_ignores_malformed_pairs() {
let params = parse_query_params("valid=ok&invalid&also_invalid=");

assert_eq!(params.get("valid"), Some(&"ok".to_string()));
assert_eq!(params.get("invalid"), None);
assert_eq!(params.get("also_invalid"), Some(&"".to_string()));
}

#[test]
fn port_is_available_reflects_current_port_usage() {
let listener = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port");
let port = listener
.local_addr()
.expect("failed to get listener addr")
.port();

assert!(!port_is_available(port));
drop(listener);
assert!(port_is_available(port));
}

#[test]
fn find_available_port_skips_ports_that_are_in_use() {
let listener = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port");
let occupied_port = listener
.local_addr()
.expect("failed to get listener addr")
.port();

let found_port = find_available_port(occupied_port).expect("should find an available port");

assert_ne!(found_port, occupied_port);
}

#[tokio::test]
async fn start_callback_server_returns_without_waiting_for_second_connection() {
let port = reserve_ephemeral_port();
let auth_code = Arc::new(Mutex::new(Some("test-code".to_string())));

let returned_code = timeout(
Duration::from_millis(300),
start_callback_server(port, auth_code),
)
.await
.expect("callback server timed out")
.expect("callback server should return code");

assert_eq!(returned_code, "test-code");
}

#[test]
fn start_callback_server_returns_bind_error_if_port_is_occupied() {
let listener = StdTcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port");
let occupied_port = listener
.local_addr()
.expect("failed to get listener addr")
.port();

let runtime = Runtime::new().expect("failed to create runtime");
let result = runtime.block_on(start_callback_server(
occupied_port,
Arc::new(Mutex::new(None::<String>)),
));

assert!(result.is_err());
let error = result.err().expect("expected bind error").to_string();
assert!(error.contains("Failed to bind"));
}

#[test]
fn callback_server_serves_waiting_error_and_success_pages_then_returns_code() {
let port = reserve_ephemeral_port();
let auth_code = Arc::new(Mutex::new(None::<String>));
let result_rx = spawn_callback_server(port, auth_code);

let (waiting_status, waiting_body) = send_http_get(port, "/");
assert_eq!(waiting_status, 200);
assert!(waiting_body.contains("Waiting for Authorization"));

let (error_status, error_body) = send_http_get(
port,
"/?error=access_denied&error_description=user%20cancelled",
);
assert_eq!(error_status, 400);
assert!(error_body.contains("Authorization Failed"));
assert!(error_body.contains("access_denied"));

let (success_status, success_body) = send_http_get(port, "/?code=abc123");
assert_eq!(success_status, 200);
assert!(success_body.contains("Successfully Signed In"));

let returned_code = result_rx
.recv_timeout(StdDuration::from_secs(2))
.expect("callback server should return in time")
.expect("callback server should return code");

assert_eq!(returned_code, "abc123");
}
}
Loading